/*!
A simple implementation of congruence closure, [`CongruenceClosure`], which drives modifications to an arbitrary union-find ADT implementing the [`UnionFind`] trait.
Based on Nieuwenhuis and Oliveras [Proof-producing Congruence Closure](https://www.cs.upc.edu/~oliveras/rta05.pdf).

A minimal [disjoint set forest](https://en.wikipedia.org/wiki/Disjoint-set_data_structure) implementation with dense `usize` nodes, [`DisjointSetForest`], is also provided.
*/
#![forbid(missing_docs, missing_debug_implementations)]
use core::{
    cmp::Ordering,
    fmt::{self, Debug, Formatter},
    hash::{BuildHasher, Hash, Hasher},
};
use hashbrown::{raw::RawTable, HashSet};
use std::{collections::hash_map::RandomState, borrow::Borrow};

mod language;
pub use language::*;

/// A simple implementation of congruence closure, parametric over an arbitrary disjoint set forest implementation.
///
/// Based on Nieuwenhuis and Oliveras [Proof-producing Congruence Closure](https://www.cs.upc.edu/~oliveras/rta05.pdf).
///
/// The closure records equations `expr = term`, where `expr` is an expression of some
/// [`Language`] `L` over nodes of type `I`. It drives a separate [`UnionFind`] structure, passed in
/// on every call, so that whenever two registered expressions become congruent their terms are
/// merged as well.
///
/// # Examples
/// ```rust
/// # use congruence::{DisjointSetForest, CongruenceClosure, UnionFind, CongruenceState};
/// let mut dsf = DisjointSetForest::with_capacity(5);
/// let mut cc = CongruenceClosure::new();
/// let mut st = CongruenceState::default();
/// let a = 0;
/// let b = 1;
/// let c = 2;
/// let d = 3;
/// let e = 4;
/// let f = 5;
/// // Set a * b = c
/// cc.equation((a, b), c, &mut dsf, &mut st);
/// assert!(dsf.is_empty());
/// // Set d * e = f
/// cc.equation((d, e), f, &mut dsf, &mut st);
/// assert!(dsf.is_empty());
/// // Set a = d
/// cc.merge(a, d, &mut dsf, &mut st);
/// assert!(!dsf.is_empty());
/// for i in 0..5 {
///     for j in 0..5 {
///         assert_eq!(dsf.node_eq(i, j), i == j || (i == a || i == d) && (j == a || j == d))
///     }
/// }
/// // Set b = e
/// cc.merge(b, e, &mut dsf, &mut st);
/// assert!(dsf.node_eq(b, e));
/// // By congruence, we now have that c = a * b = d * e = f
/// assert!(dsf.node_eq(c, f));
/// assert!(dsf.node_eq(a, d));
/// assert!(!dsf.node_eq(a, b));
/// assert!(!dsf.node_eq(a, c));
/// assert!(!dsf.node_eq(a, e));
/// assert!(!dsf.node_eq(a, f));
/// assert!(!dsf.node_eq(b, c));
/// assert!(!dsf.node_eq(b, d));
/// assert!(!dsf.node_eq(b, f));
/// assert!(!dsf.node_eq(c, d));
/// assert!(!dsf.node_eq(c, e));
/// assert!(!dsf.node_eq(d, e));
/// assert!(!dsf.node_eq(d, f));
/// assert!(!dsf.node_eq(e, f));
/// ```
#[derive(Clone)]
pub struct CongruenceClosure<L, I = usize, S = hashbrown::hash_map::DefaultHashBuilder> {
    /// The use lists, keyed by representative: for each representative `r`, the registered
    /// expressions that have a dependency whose representative is (or was) `r`.
    ///
    /// Entries may be stale. An expression dropped from `lookup` during `merge`, because it became
    /// congruent to an already-registered one, stays in the use lists of its other dependencies.
    /// Such entries are skipped when they are encountered.
    use_lists: RawTable<(I, Vec<L>)>,
    /// The lookup table: each registered expression paired with the term it was equated with.
    ///
    /// Entries are hashed and compared modulo the union-find (via the `Language` `hash_one_mod*`
    /// and `eq_mod*` methods). `equation` and `merge` keep at most one entry per congruence class
    /// of expressions.
    lookup: RawTable<(L, I)>,
    /// The shared hasher for `use_lists` and `lookup`
    hasher: S,
}

impl<L: Debug, I: Debug, S> Debug for CongruenceClosure<L, I, S> {
    /// Formats the closure as a map from each registered expression to the term it was equated
    /// with.
    ///
    /// Only the lookup table is shown; the use lists and the hasher are omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // SAFETY: `self.lookup` is borrowed immutably for the whole call, so the raw iterator and
        // every bucket reference it yields remain valid until `finish` returns.
        let equations = unsafe { self.lookup.iter().map(|bucket| bucket.as_ref()) };
        f.debug_map()
            .entries(equations.map(|(expr, result)| (expr, result)))
            .finish()
    }
}

impl<L, I, S: Default + BuildHasher> Default for CongruenceClosure<L, I, S> {
    fn default() -> Self {
        Self {
            use_lists: Default::default(),
            lookup: Default::default(),
            hasher: Default::default(),
        }
    }
}

impl<L, I> CongruenceClosure<L, I>
where
    L: Clone + Eq,
    I: Copy + Eq,
{
    /// Creates an empty congruence closure using the default hasher and no pre-reserved capacity.
    ///
    /// # Example
    /// ```rust
    /// # use congruence::CongruenceClosure;
    /// let cc = CongruenceClosure::<(usize, usize)>::new();
    /// assert!(cc.is_empty());
    /// ```
    #[inline]
    pub fn new() -> Self {
        Self {
            use_lists: RawTable::new(),
            lookup: RawTable::new(),
            hasher: Default::default(),
        }
    }

    /// Creates an empty congruence closure using the default hasher, with room for `nodes`
    /// use lists and `pairs` registered expressions.
    ///
    /// # Example
    /// ```rust
    /// # use congruence::CongruenceClosure;
    /// let cc = CongruenceClosure::<(usize, usize)>::with_capacity(5, 5);
    /// assert!(cc.is_empty());
    /// ```
    #[inline]
    pub fn with_capacity(nodes: usize, pairs: usize) -> Self {
        let use_lists = RawTable::with_capacity(nodes);
        let lookup = RawTable::with_capacity(pairs);
        Self {
            use_lists,
            lookup,
            hasher: Default::default(),
        }
    }
}

impl<L, I, S> CongruenceClosure<L, I, S>
where
    L: Clone,
    I: Hash + Copy + Eq,
    S: BuildHasher,
{
    /// Create a new, empty congruence closure with the given hasher
    #[inline]
    pub fn with_hasher(hasher: S) -> Self {
        CongruenceClosure {
            use_lists: RawTable::new(),
            lookup: RawTable::new(),
            hasher,
        }
    }

    /// Create a new, empty congruence closure with the given node capacity, pair capacity, and hasher
    ///
    /// `nodes` sizes the use-list table (one entry per representative) and `pairs` sizes the
    /// lookup table (one entry per registered expression).
    #[inline]
    pub fn with_capacity_and_hasher(nodes: usize, pairs: usize, hasher: S) -> Self {
        CongruenceClosure {
            use_lists: RawTable::with_capacity(nodes),
            lookup: RawTable::with_capacity(pairs),
            hasher,
        }
    }

    /// Whether this congruence closure is empty, i.e. contains no *congruence* relations
    ///
    /// # Example
    /// ```rust
    /// # use congruence::{CongruenceClosure, DisjointSetForest, CongruenceState};
    /// let mut cc = CongruenceClosure::<(usize, usize)>::new();
    /// let mut dsf = DisjointSetForest::new();
    /// let mut cs = CongruenceState::new();
    /// assert!(cc.is_empty());
    ///
    /// // This operation only adds something to the disjoint set forest, leaving the congruence
    /// // closure itself empty
    /// cc.merge(3, 4, &mut dsf, &mut cs);
    /// assert!(cc.is_empty());
    /// assert!(!dsf.is_empty());
    ///
    /// // We now actually add something to the congruence closure by equating an *expression*, here
    /// // simply a pair of `usize` indices, to a term
    /// cc.equation((1, 2), 3, &mut dsf, &mut cs);
    /// assert!(!cc.is_empty());
    /// assert!(!dsf.is_empty());
    ///
    /// // On the other hand, equating an expression to a term without introducing any new term-level
    /// // equalities only adds things to the congruence closure, leaving the union-find structure
    /// // unchanged:
    /// cc.clear();
    /// dsf.clear();
    /// cc.equation((1, 2), 3, &mut dsf, &mut cs);
    /// assert!(!cc.is_empty());
    /// assert!(dsf.is_empty());
    /// ```
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.use_lists.len() == 0 && self.lookup.len() == 0
    }

    /// Clear this congruence closure, keeping its capacity but otherwise resetting it to empty
    ///
    /// # Example
    /// ```rust
    /// # use congruence::{CongruenceClosure, DisjointSetForest, CongruenceState};
    /// let mut cc = CongruenceClosure::<(usize, usize)>::new();
    /// let mut dsf = DisjointSetForest::new();
    /// let mut cs = CongruenceState::new();
    /// assert!(cc.is_empty());
    /// cc.equation((1, 2), 3, &mut dsf, &mut cs);
    /// assert!(!cc.is_empty());
    /// cc.clear();
    /// assert!(cc.is_empty());
    /// ```
    #[inline]
    pub fn clear(&mut self) {
        self.use_lists.clear();
        self.lookup.clear();
    }

    /// Register an equation of the form `expr = result`, where `expr` is an expression in this congruence closure's language
    ///
    /// If an expression equal to `expr` modulo `union_find` is already registered, `result` is
    /// merged with that expression's term via [`Self::merge`], which may trigger further merges.
    /// Otherwise `expr` is appended to the use list of the representative of each of its
    /// dependencies and stored in the lookup table together with `result`.
    ///
    /// # Example
    /// ```rust
    /// # use congruence::{CongruenceClosure, DisjointSetForest, CongruenceState, UnionFind};
    /// let mut cc = CongruenceClosure::<(usize, usize)>::new();
    /// let mut dsf = DisjointSetForest::new();
    /// let mut cs = CongruenceState::new();
    /// // Say 0 ~ 1 and 2 ~ 3
    /// cc.merge(0, 1, &mut dsf, &mut cs);
    /// cc.merge(2, 3, &mut dsf, &mut cs);
    /// assert!(dsf.node_eq(0, 1));
    /// assert!(dsf.node_eq(2, 3));
    /// assert!(!dsf.node_eq(4, 5));
    ///
    /// // If we set (0, 2) ~ 4 and (1, 3) ~ 5, we can then deduce 4 ~ 5
    /// cc.equation((0, 2), 4, &mut dsf, &mut cs);
    /// assert!(dsf.node_eq(0, 1));
    /// assert!(dsf.node_eq(2, 3));
    /// assert!(!dsf.node_eq(4, 5));
    /// cc.equation((1, 3), 5, &mut dsf, &mut cs);
    /// assert!(dsf.node_eq(0, 1));
    /// assert!(dsf.node_eq(2, 3));
    /// assert!(dsf.node_eq(4, 5));
    /// ```
    pub fn equation<U>(
        &mut self,
        expr: L,
        result: I,
        union_find: &mut U,
        state: &mut CongruenceState<I>,
    ) where
        L: Language<U, I>,
        U: UnionFind<I>,
    {
        let hash = expr.hash_one_mod_mut(&self.hasher, &mut |uf, ix| uf.find_mut(ix), union_find);
        if let Some(&(_, b)) = self.lookup.get(hash, |(k, _r)| {
            expr.eq_mod_mut(k, &mut |uf, i, j| uf.node_eq_mut(i, j), union_find)
        }) {
            // A congruent expression is already registered: its term must equal `result`.
            self.merge(result, b, union_find, state)
        } else {
            // Register `expr` as a user of each of its dependencies' representatives, so that a
            // later merge of any of them rehashes this entry.
            let r: Result<(), ()> = expr.visit_deps_mut(
                &mut |uf, ix| Ok(self.push_user(uf.find(ix), expr.clone())),
                union_find,
            );
            debug_assert_eq!(r, Ok(()));
            self.lookup.insert(hash, (expr, result), |(k, _r)| {
                k.hash_one_mod(&self.hasher, &mut |uf, ix| uf.find(ix), union_find)
            });
        }
    }

    /// Get a representative for the congruence class of `expr` in this context, if one exists.
    ///
    /// Returns the stored `(expression, term)` pair whose expression equals `expr` modulo
    /// `union_find`, or `None` if no such expression has been registered.
    ///
    /// # Example
    /// ```rust
    /// # use congruence::{CongruenceClosure, DisjointSetForest, CongruenceState};
    /// let mut cc = CongruenceClosure::<(usize, usize)>::new();
    /// let mut dsf = DisjointSetForest::new();
    /// let mut cs = CongruenceState::new();
    /// assert_eq!(cc.lookup(&(1, 2), &dsf), None);
    /// assert_eq!(cc.lookup(&(0, 2), &dsf), None);
    /// assert_eq!(cc.lookup(&(0, 1), &dsf), None);
    /// cc.equation((1, 2), 3, &mut dsf, &mut cs);
    /// assert_eq!(cc.lookup(&(1, 2), &dsf), Some(&((1, 2), 3)));
    /// assert_eq!(cc.lookup(&(0, 2), &dsf), None);
    /// assert_eq!(cc.lookup(&(0, 1), &dsf), None);
    /// cc.merge(0, 1, &mut dsf, &mut cs);
    /// let repr = *cc.lookup(&(1, 2), &dsf).unwrap();
    /// assert_eq!(cc.lookup(&(0, 2), &dsf), Some(&repr));
    /// assert!(repr == ((1, 2), 3) || repr == ((0, 2), 3));
    /// assert_eq!(cc.lookup(&(0, 1), &dsf), None);
    /// ```
    pub fn lookup<B, U>(&self, expr: &B, union_find: &U) -> Option<&(L, I)>
    where
        B: Language<U, I>,
        L: Borrow<B>,
        U: UnionFind<I>,
    {
        let hash = expr.hash_one_mod(&self.hasher, &mut |uf, ix| uf.find(ix), union_find);
        self.lookup.get(hash, |(k, _r)| {
            expr.eq_mod(k.borrow(), &mut |uf, i, j| uf.node_eq(i, j), union_find)
        })
    }

    /// Get a representative for the congruence class of `expr` in this context, if one exists.
    ///
    /// Always returns the same result as [`Self::lookup`], but may optimize the union-find data structure provided in the process.
    pub fn lookup_mut<B, U>(&self, expr: &B, union_find: &mut U) -> Option<&(L, I)>
    where
        B: Language<U, I>,
        L: Borrow<B>,
        U: UnionFind<I>,
    {
        let hash = expr.hash_one_mod_mut(&self.hasher, &mut |uf, ix| uf.find_mut(ix), union_find);
        self.lookup.get(hash, |(k, _r)| {
            expr.eq_mod_mut(k.borrow(), &mut |uf, i, j| uf.node_eq_mut(i, j), union_find)
        })
    }

    /// Check whether an expression in the equation language is congruent to a given index
    ///
    /// Returns `true` iff some registered expression equals `expr` modulo `union_find` and its
    /// term is in the same equivalence class as `ix`.
    ///
    /// # Example
    /// ```rust
    /// # use congruence::{CongruenceClosure, DisjointSetForest, CongruenceState};
    /// let mut cc = CongruenceClosure::<(usize, usize)>::new();
    /// let mut dsf = DisjointSetForest::new();
    /// let mut cs = CongruenceState::new();
    ///
    /// assert!(!cc.expr_cong(&(0, 1), 4, &dsf));
    /// assert!(!cc.expr_cong(&(2, 3), 4, &dsf));
    ///
    /// cc.merge(0, 2, &mut dsf, &mut cs);
    /// cc.merge(1, 3, &mut dsf, &mut cs);
    /// assert!(!cc.expr_cong(&(0, 1), 4, &dsf));
    /// assert!(!cc.expr_cong(&(2, 3), 4, &dsf));
    ///
    /// // If we set `(0, 1) ~ 4`, we can deduce `(2, 3) ~ 4`, even though we've never inserted the pair `(2, 3)`
    /// cc.equation((0, 1), 4, &mut dsf, &mut cs);
    /// assert!(cc.expr_cong(&(0, 1), 4, &dsf));
    /// assert!(cc.expr_cong(&(2, 3), 4, &dsf));
    /// ```
    pub fn expr_cong<B, U>(&self, expr: &B, ix: I, union_find: &U) -> bool
    where
        B: Language<U, I>,
        L: Borrow<B>,
        U: UnionFind<I>,
    {
        self.lookup(expr, union_find)
            .map_or(false, |&(_, repr)| union_find.node_eq(repr, ix))
    }

    /// Check whether an expression in the equation language is congruent to a given index
    ///
    /// Always returns the same result as [`Self::expr_cong`], but may optimize the union-find data structure provided in the process.
    pub fn expr_cong_mut<B, U>(&self, expr: &B, ix: I, union_find: &mut U) -> bool
    where
        B: Language<U, I>,
        L: Borrow<B>,
        U: UnionFind<I>,
    {
        self.lookup_mut(expr, union_find)
            .map_or(false, |&(_, repr)| union_find.node_eq_mut(repr, ix))
    }

    /// Merge the equivalence classes of two nodes
    ///
    /// After unioning `a` and `b`, any registered expressions that become congruent have their
    /// terms merged too, and this is repeated until no new equalities arise. `state` is used as
    /// the work stack of pending equalities; it is empty again when this returns.
    ///
    /// # Example
    /// ```rust
    /// # use congruence::{CongruenceClosure, DisjointSetForest, CongruenceState, UnionFind};
    /// let mut cc = CongruenceClosure::<(usize, usize)>::new();
    /// let mut dsf = DisjointSetForest::new();
    /// let mut raw_dsf = DisjointSetForest::new();
    /// let mut cs = CongruenceState::new();
    ///
    /// // Calls to `cc.merge` without any equations added to the congruence context are equivalent to calls to `dsf.union`
    /// cc.merge(0, 1, &mut dsf, &mut cs);
    /// raw_dsf.union(0, 1);
    /// assert_eq!(dsf, raw_dsf);
    ///
    /// cc.merge(2, 3, &mut dsf, &mut cs);
    /// raw_dsf.union(2, 3);
    /// assert_eq!(dsf, raw_dsf);
    ///
    /// // Similarly, with only one equation in the context, they do not change the union-find ADT,
    /// // though the equation itself may be modified as the union-find is updated:
    /// assert_eq!(cc.lookup(&(4, 5), &dsf), None);
    /// cc.equation((4, 5), 6, &mut dsf, &mut cs);
    /// assert_eq!(dsf, raw_dsf);
    /// assert_eq!(cc.lookup(&(4, 5), &dsf), Some(&((4, 5), 6)));
    ///
    /// cc.merge(0, 4, &mut dsf, &mut cs);
    /// raw_dsf.union(0, 4);
    /// assert_eq!(dsf, raw_dsf);
    ///
    /// // Once we have a set of equations which overlap, however, calling `merge` can trigger multiple `unions` in the DSF,
    /// // implementing congruence closure:
    /// cc.equation((4, 7), 8, &mut dsf, &mut cs);
    /// assert_eq!(dsf, raw_dsf);
    ///
    /// cc.merge(7, 5, &mut dsf, &mut cs);
    /// raw_dsf.union(7, 5);
    /// assert!(raw_dsf.refines(&dsf));
    /// assert!(!dsf.refines(&raw_dsf));
    /// // In particular:
    /// assert!(!raw_dsf.node_eq(6, 8));
    /// assert!(dsf.node_eq(6, 8));
    /// ```
    #[inline]
    pub fn merge<U>(
        &mut self,
        mut a: I,
        mut b: I,
        union_find: &mut U,
        state: &mut CongruenceState<I>,
    ) where
        L: Language<U, I>,
        U: UnionFind<I>,
    {
        loop {
            let a_repr = union_find.find_mut(a);
            let b_repr = union_find.find_mut(b);
            if a_repr != b_repr {
                let new_repr = union_find.pre_union_find(a_repr, b_repr);
                let old_repr = if new_repr == a_repr {
                    b_repr
                } else {
                    debug_assert!(new_repr == b_repr);
                    a_repr
                };
                debug_assert!(old_repr != new_repr);

                // Phase 1: pull every expression depending on `old_repr` out of the lookup table.
                //
                // This must finish before anything is re-inserted. `RawTable::insert` may grow or
                // rehash the table, and the rehash closure in phase 2 hashes keys *as if*
                // `old_repr` had already been merged into `new_repr`. Any entry still stored under
                // its pre-merge hash at that point would be relocated to its post-merge slot, and
                // a later `remove_entry(curr_hash, ..)` would miss it. It would then never be
                // registered in `new_repr`'s use list, and congruences with it would go unnoticed.
                let users = self.remove_use_list(old_repr).unwrap_or_default();
                let mut moved = Vec::with_capacity(users.len());
                for key in users {
                    let curr_hash = key.hash_one_mod_mut(
                        &self.hasher,
                        &mut |uf, ix| uf.find_mut(ix),
                        union_find,
                    );
                    let removed = self.lookup.remove_entry(curr_hash, |(k, _r)| {
                        key.eq_mod_mut(k, &mut |uf, i, j| uf.node_eq_mut(i, j), union_find)
                    });
                    // A use-list entry without a matching lookup entry is stale (its equation was
                    // already dropped in favour of a congruent one) and is simply discarded.
                    if let Some((_stored_key, c)) = removed {
                        // Hash `key` as it will be once `old_repr` is merged into `new_repr`, and
                        // note whether it already depends on `new_repr`, in which case it need not
                        // be pushed onto `new_repr`'s use list again.
                        let mut has_new_repr = false;
                        let new_hash = key.hash_one_mod_mut(
                            &self.hasher,
                            &mut |uf, ix| {
                                let found = uf.find_mut(ix);
                                has_new_repr |= found == new_repr;
                                if found == old_repr {
                                    new_repr
                                } else {
                                    found
                                }
                            },
                            union_find,
                        );
                        moved.push((key, c, new_hash, has_new_repr));
                    }
                }

                // Phase 2: re-insert the removed expressions under their post-merge hashes. Every
                // entry remaining in the table now hashes identically under the old-to-new
                // mapping used by the rehash closure, so growing the table here is harmless.
                for (key, c, new_hash, has_new_repr) in moved {
                    // Look for an expression that becomes equal to `key` once `old_repr` and
                    // `new_repr` are identified (the union itself has not happened yet).
                    let existing = self.lookup.get(new_hash, |(k, _r)| {
                        key.eq_mod_mut(
                            k,
                            &mut |uf, i, j| {
                                if i == j {
                                    true
                                } else {
                                    let i = uf.find(i);
                                    let j = uf.find(j);
                                    i == j
                                        || (i == old_repr || i == new_repr)
                                            && (j == old_repr || j == new_repr)
                                }
                            },
                            union_find,
                        )
                    });
                    if let Some(&(_, d)) = existing {
                        // Two congruent expressions: their terms must be merged as well.
                        state.pending.push((c, d));
                    } else {
                        if !has_new_repr {
                            self.push_user(new_repr, key.clone());
                        }
                        self.lookup.insert(new_hash, (key, c), |(k, _r)| {
                            k.hash_one_mod(
                                &self.hasher,
                                &mut |uf, i| {
                                    let found = uf.find(i);
                                    if found == old_repr {
                                        new_repr
                                    } else {
                                        found
                                    }
                                },
                                union_find,
                            )
                        });
                    }
                }

                union_find.union(a_repr, b_repr);
                debug_assert!(union_find.find_mut(a) == new_repr);
                debug_assert!(union_find.find_mut(b) == new_repr);
            }

            if let Some(pending) = state.pending.pop() {
                a = pending.0;
                b = pending.1;
            } else {
                return;
            }
        }
    }

    /// Check this data structure's invariants with respect to a union-find ADT
    ///
    /// Returns `false` in either of two cases:
    /// - some use list is keyed by a node that is not its own representative in `union_find`;
    /// - some expression in a use list has no entry in the lookup table that is equal to it
    ///   modulo `union_find`.
    ///
    /// Otherwise it returns `true`. This may optimize `union_find` in the process.
    pub fn check_invariants<U>(&self, union_find: &mut U) -> bool
    where
        L: Language<U, I>,
        U: UnionFind<I>,
    {
        let mut seen = HashSet::with_capacity_and_hasher(self.lookup.len(), RandomState::default());
        // SAFETY: the raw iterator and the bucket references it yields are only used while `self`
        // is immutably borrowed, so no bucket can be freed or moved in the meantime.
        unsafe {
            for use_list in self.use_lists.iter() {
                let (used, users) = use_list.as_ref();
                if union_find.find_mut(*used) != *used {
                    return false;
                }
                for user in users {
                    let hash =
                        user.hash_one_mod_mut(&self.hasher, &mut |uf, ix| uf.find(ix), union_find);
                    if let Some(lookup) = self.lookup.get(hash, |(k, _r)| {
                        k.eq_mod_mut(user, &mut |uf, i, j| uf.node_eq_mut(i, j), union_find)
                    }) {
                        seen.insert(lookup as *const _);
                    } else {
                        return false;
                    }
                }
            }
            // NOTE(review): the reverse check (every lookup entry is reachable from some use list)
            // is disabled, so `seen` is currently only written. This is presumably because
            // `equation` only registers an expression with the representatives of the dependencies
            // it visits, so expressions without dependencies never appear in any use list.
            // Confirm this before re-enabling.
            /*
            for lookup in self.lookup.iter() {
                if !seen.contains(&(lookup.as_ptr() as *const _)) {
                    return false;
                }
            }
            */
        }
        true
    }

    /// Remove and return the use list of the representative `used`, if it has one.
    #[inline(always)]
    fn remove_use_list(&mut self, used: I) -> Option<Vec<L>> {
        let hash = hash_one(&self.hasher, &used);
        self.use_lists
            .remove_entry(hash, |(k, _v)| *k == used)
            .map(|(_, users)| users)
    }

    /// Append `user` to the use list of `used`, creating the list if it does not exist yet.
    #[inline(always)]
    fn push_user(&mut self, used: I, user: L) {
        let hash = hash_one(&self.hasher, &used);
        if let Some((k, list)) = self.use_lists.get_mut(hash, |(k, _v)| *k == used) {
            debug_assert!(*k == used);
            list.push(user)
        } else {
            self.use_lists
                .insert_entry(hash, (used, Vec::with_capacity(1)), |(k, _v)| {
                    hash_one(&self.hasher, k)
                })
                .1
                .push(user);
        }
    }
}

/// Scratch state used while computing congruence closure.
///
/// Holds the stack of pending equalities discovered by [`CongruenceClosure::merge`] that have not
/// been processed yet. `merge` drains the stack before returning, so one `CongruenceState` can be
/// reused across calls to keep its allocation.
#[derive(Clone)]
pub struct CongruenceState<I = usize> {
    /// Pending equalities `(a, b)` still to be merged, processed in LIFO order.
    pending: Vec<(I, I)>,
}

impl<I> CongruenceState<I> {
    /// Creates a new, null congruence state. Guaranteed not to allocate
    pub const fn new() -> CongruenceState<I> {
        CongruenceState {
            pending: Vec::new(),
        }
    }

    /// Creates a new congruence state with the given capacity.
    pub fn with_capacity(cap: usize) -> CongruenceState<I> {
        CongruenceState {
            pending: Vec::with_capacity(cap),
        }
    }
}

impl<I: Debug> Debug for CongruenceState<I> {
    /// Shows the pending equalities together with the allocated capacity of the pending stack.
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let capacity = self.pending.capacity();
        let mut out = f.debug_struct("CongruenceState");
        out.field("pending", &self.pending);
        out.field("capacity", &capacity);
        out.finish()
    }
}

impl<I> Default for CongruenceState<I> {
    /// Equivalent to [`CongruenceState::new`]: an empty state that has not allocated.
    fn default() -> Self {
        Self::new()
    }
}

/// A trait implemented by data-structures implementing the union-find ADT.
///
/// Implementors provide [`Self::union_find`], [`Self::pre_union_find`], [`Self::union`],
/// [`Self::find`] and [`Self::node_eq`]. [`Self::find_mut`] and [`Self::node_eq_mut`] default to
/// their immutable counterparts. They may be overridden by implementations that can optimize the
/// underlying structure when given a mutable borrow.
///
/// See [`DisjointSetForest`] for an example implementation and usage
pub trait UnionFind<I = usize> {
    /// Take the union of two nodes, returning the new root.
    ///
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// # let mut dsf = DisjointSetForest::with_capacity(3);
    /// # let a = 0;
    /// # let b = 1;
    /// let union = dsf.union_find(a, b);
    /// ```
    /// is equivalent to
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// # let mut dsf = DisjointSetForest::with_capacity(3);
    /// # let a = 0;
    /// # let b = 1;
    /// dsf.union(a, b);
    /// let union = dsf.find(a);
    /// assert_eq!(union, dsf.find(b));
    /// ```
    /// but may be optimized better.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(3);
    /// assert!(!dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(0, 2));
    /// assert!(!dsf.node_eq(1, 2));
    /// let union = dsf.union_find(0, 1);
    /// assert!(union == 0 || union == 1);
    /// assert!(dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(0, 2));
    /// assert!(!dsf.node_eq(1, 2));
    /// ```
    fn union_find(&mut self, left: I, right: I) -> I;

    /// Get what *would* be the new root if two nodes were unioned.
    ///
    /// That is, the following behaviour is guaranteed, as long as no other union is performed in
    /// between. Lookups such as `find` are allowed in between; [`CongruenceClosure::merge`] also
    /// relies on `find_mut` and `node_eq_mut` being allowed there.
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// # let mut dsf = DisjointSetForest::with_capacity(3);
    /// # let a = 0;
    /// # let b = 1;
    /// # let c = 2;
    /// # let z = 3;
    /// let union = dsf.pre_union_find(a, b);
    /// // Note that, at this point, it is *not* guaranteed that `dsf.node_eq(a, b)`!
    /// dsf.find(c);
    /// // ...
    /// dsf.find(z);
    /// assert_eq!(dsf.union_find(a, b), union);
    /// ```
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(3);
    /// let root = dsf.pre_union_find(0, 1);
    /// assert!(root == 0 || root == 1);
    /// assert_eq!(dsf.union_find(0, 1), root);
    /// assert!(dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(0, 2));
    /// assert!(!dsf.node_eq(1, 2));
    /// ```
    fn pre_union_find(&mut self, left: I, right: I) -> I;

    /// Merge the equivalence classes of two nodes
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(3);
    /// assert!(!dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(0, 2));
    /// assert!(!dsf.node_eq(1, 2));
    /// dsf.union(0, 1);
    /// assert!(dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(0, 2));
    /// assert!(!dsf.node_eq(1, 2));
    /// ```
    fn union(&mut self, left: I, right: I);

    /// Find the representative of a node, without requiring a mutable borrow of this data structure.
    ///
    /// Always returns the same result as [`Self::find_mut`], but may not optimize the underlying data structure.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(3);
    /// assert_eq!(dsf.find(0), 0);
    /// assert_eq!(dsf.find(1), 1);
    /// dsf.union(0, 1);
    /// assert_eq!(dsf.find(0), dsf.find(1));
    /// ```
    fn find(&self, node: I) -> I;

    /// Check whether two nodes are in the same equivalence class, without requiring a mutable borrow of this data structure.
    ///
    /// Always returns the same result as [`Self::node_eq_mut`], but may not optimize the underlying data structure.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(10);
    /// dsf.union(3, 5);
    /// dsf.union(6, 3);
    /// dsf.union(7, 2);
    /// for i in 0..10 {
    ///     for j in 0..10 {
    ///         assert_eq!(dsf.node_eq_mut(i, j), dsf.node_eq(i, j))
    ///     }
    /// }
    /// ```
    fn node_eq(&self, left: I, right: I) -> bool;

    /// Find the representative of a node, while potentially optimizing the underlying data structure.
    ///
    /// Always returns the same result as [`Self::find`], but may optimize the underlying data structure.
    /// The default implementation simply calls [`Self::find`].
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(3);
    /// assert_eq!(dsf.find_mut(0), 0);
    /// assert_eq!(dsf.find_mut(1), 1);
    /// dsf.union(0, 1);
    /// assert_eq!(dsf.find_mut(0), dsf.find_mut(1));
    /// ```
    #[inline(always)]
    fn find_mut(&mut self, node: I) -> I {
        self.find(node)
    }

    /// Check whether two nodes are in the same equivalence class, while potentially optimizing the underlying data structure.
    ///
    /// Always returns the same result as [`Self::node_eq`], but may optimize the underlying data structure.
    /// The default implementation simply calls [`Self::node_eq`].
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(10);
    /// dsf.union(3, 5);
    /// dsf.union(6, 3);
    /// dsf.union(7, 2);
    /// for i in 0..10 {
    ///     for j in 0..10 {
    ///         assert_eq!(dsf.node_eq_mut(i, j), dsf.node_eq(i, j))
    ///     }
    /// }
    /// ```
    #[inline(always)]
    fn node_eq_mut(&mut self, left: I, right: I) -> bool {
        self.node_eq(left, right)
    }
}

/// A minimal [disjoint set forest](https://en.wikipedia.org/wiki/Disjoint-set_data_structure) with dense `usize` nodes, implementing the [`UnionFind`] trait.
///
/// Has support for up to `2**(64 - 8)` nodes, which should be enough to exhaust the entire address space of modern CPUs.
///
/// # Examples
/// ```rust
/// # use congruence::{DisjointSetForest, UnionFind};
/// let mut dsf = DisjointSetForest::with_capacity(4);
/// for i in 0..4 {
///     for j in 0..4 {
///         assert_eq!(dsf.node_eq(i, j), i == j)
///     }
/// }
/// dsf.union(0, 1);
/// for i in 0..4 {
///     for j in 0..4 {
///         assert_eq!(
///             dsf.node_eq(i, j),
///             i == j || (i == 1 && j == 0) || (i == 0 && j == 1)
///         )
///     }
/// }
/// dsf.union(2, 3);
/// for i in 0..4 {
///     for j in 0..4 {
///         assert_eq!(
///             dsf.node_eq(i, j),
///             i == j
///                 || (i == 1 && j == 0)
///                 || (i == 0 && j == 1)
///                 || (i == 2 && j == 3)
///                 || (i == 3 && j == 2)
///         )
///     }
/// }
/// dsf.union(1, 2);
/// for i in 0..4 {
///     for j in 0..4 {
///         assert!(dsf.node_eq(i, j))
///     }
/// }
/// ```
#[derive(Debug, Default, Clone)]
pub struct DisjointSetForest {
    /// Per-node storage indexed by node; a node whose parent is itself is the root of its class.
    nodes: Vec<Node>,
}

impl DisjointSetForest {
    /// Create a new, empty disjoint set forest. Guaranteed not to allocate.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let dsf = DisjointSetForest::new();
    /// assert!(dsf.is_empty());
    /// assert_eq!(dsf.capacity(), 0);
    /// ```
    #[inline(always)]
    pub const fn new() -> DisjointSetForest {
        DisjointSetForest { nodes: Vec::new() }
    }

    /// Create a new, empty disjoint set forest with room for at least `n` nodes.
    ///
    /// No nodes are created up front. Every index starts out as a singleton, and storage is only
    /// filled in as unions touch new indices.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let dsf = DisjointSetForest::with_capacity(5);
    /// assert!(dsf.is_empty());
    /// assert!(dsf.capacity() >= 5);
    /// ```
    #[inline]
    pub fn with_capacity(n: usize) -> DisjointSetForest {
        DisjointSetForest {
            nodes: Vec::with_capacity(n),
        }
    }

    /// Reserve capacity for `n` additional nodes
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let mut dsf = DisjointSetForest::new();
    /// assert!(dsf.is_empty());
    /// assert_eq!(dsf.capacity(), 0);
    /// dsf.reserve(5);
    /// assert!(dsf.is_empty());
    /// assert!(dsf.capacity() >= 5);
    /// ```
    #[inline(always)]
    pub fn reserve(&mut self, n: usize) {
        self.nodes.reserve(n);
    }

    /// Get whether this forest is empty, i.e. contains no *relations* between distinct elements.
    ///
    /// This is true exactly when every stored node is its own root, so it may be `true` even if
    /// nodes have been allocated.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let mut dsf = DisjointSetForest::new();
    /// assert!(dsf.is_empty());
    /// ```
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.nodes
            .iter()
            .enumerate()
            .all(|(ix, node)| node.parent() == ix)
    }

    /// Get the current capacity of this forest
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let mut dsf = DisjointSetForest::new();
    /// assert_eq!(dsf.capacity(), 0);
    /// dsf.reserve(1);
    /// assert!(dsf.capacity() >= 1);
    /// dsf.reserve(5);
    /// assert!(dsf.capacity() >= 6);
    /// ```
    #[inline(always)]
    pub fn capacity(&self) -> usize {
        self.nodes.capacity()
    }

    /// Clear this forest, removing all nodes but preserving its capacity
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::new();
    /// assert!(dsf.is_empty());
    /// assert!(!dsf.node_eq(5, 3));
    /// assert_eq!(dsf.capacity(), 0);
    /// dsf.union(5, 3);
    /// assert!(!dsf.is_empty());
    /// assert!(dsf.node_eq(5, 3));
    /// let capacity = dsf.capacity();
    /// assert!(capacity >= 5);
    /// dsf.clear();
    /// assert!(dsf.is_empty());
    /// assert!(!dsf.node_eq(5, 3));
    /// assert_eq!(dsf.capacity(), capacity);
    /// ```
    #[inline(always)]
    pub fn clear(&mut self) {
        self.nodes.clear()
    }

    /// Append a new singleton node whose index is the current length of the forest.
    fn insert(&mut self) {
        self.nodes.push(Node::new(self.nodes.len()))
    }

    /// Get whether this disjoint set forest refines another, i.e., if `a ~ b` in `self`, then `a ~ b` in `other`
    ///
    /// Every node is compared against its own root. If each node is equivalent to its root in
    /// `other`, then any two nodes sharing a root in `self` are also equivalent in `other`.
    /// Both structures may be path-compressed as a side effect.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf_0 = DisjointSetForest::with_capacity(5);
    /// let mut dsf_1 = DisjointSetForest::with_capacity(5);
    /// assert!(dsf_0.refines_mut(&mut dsf_1));
    /// assert!(dsf_0.refines_mut(&mut dsf_1));
    /// dsf_0.union(3, 4);
    /// dsf_0.union(3, 2);
    /// assert!(!dsf_0.refines_mut(&mut dsf_1));
    /// assert!(dsf_1.refines_mut(&mut dsf_0));
    /// dsf_1.union(3, 4);
    /// assert!(!dsf_0.refines_mut(&mut dsf_1));
    /// assert!(dsf_1.refines_mut(&mut dsf_0));
    /// dsf_1.union(3, 5);
    /// assert!(!dsf_0.refines_mut(&mut dsf_1));
    /// assert!(!dsf_1.refines_mut(&mut dsf_0));
    /// dsf_1.union(3, 2);
    /// assert!(dsf_0.refines_mut(&mut dsf_1));
    /// assert!(!dsf_1.refines_mut(&mut dsf_0));
    /// dsf_0.union(3, 5);
    /// assert!(dsf_0.refines_mut(&mut dsf_1));
    /// assert!(dsf_1.refines_mut(&mut dsf_0));
    /// ```
    pub fn refines_mut(&mut self, other: &mut impl UnionFind) -> bool {
        // Indices past `self.nodes.len()` are singletons in `self`, so they impose no constraint.
        (0..self.nodes.len()).all(|ix| {
            let root = self.find_mut(ix);
            other.node_eq_mut(ix, root)
        })
    }

    /// Get whether this disjoint set forest refines another, i.e., if `a ~ b` in `self`, then `a ~ b` in `other`, without performing path compression
    ///
    /// Always returns the same result as [`Self::refines_mut`], but does not optimize either
    /// data structure, leading to worse amortized performance.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf_0 = DisjointSetForest::with_capacity(5);
    /// let mut dsf_1 = DisjointSetForest::with_capacity(5);
    /// assert!(dsf_0.refines(&dsf_1));
    /// assert!(dsf_0.refines(&dsf_1));
    /// dsf_0.union(3, 4);
    /// dsf_0.union(3, 2);
    /// assert!(!dsf_0.refines(&dsf_1));
    /// assert!(dsf_1.refines(&dsf_0));
    /// dsf_1.union(3, 4);
    /// assert!(!dsf_0.refines(&dsf_1));
    /// assert!(dsf_1.refines(&dsf_0));
    /// dsf_1.union(3, 5);
    /// assert!(!dsf_0.refines(&dsf_1));
    /// assert!(!dsf_1.refines(&dsf_0));
    /// dsf_1.union(3, 2);
    /// assert!(dsf_0.refines(&dsf_1));
    /// assert!(!dsf_1.refines(&dsf_0));
    /// dsf_0.union(3, 5);
    /// assert!(dsf_0.refines(&dsf_1));
    /// assert!(dsf_1.refines(&dsf_0));
    /// ```
    pub fn refines(&self, other: &DisjointSetForest) -> bool {
        // Same reasoning as `refines_mut`: it suffices that each node is equivalent to its root.
        (0..self.nodes.len()).all(|ix| other.node_eq(ix, self.find(ix)))
    }

    /// Take the union of two representative (root) nodes by rank, returning the new root.
    ///
    /// Both `left` and `right` must be in-bounds roots. On a rank tie, `left` becomes the root
    /// and its rank is bumped by one. If `left == right`, the nodes are already in the same class,
    /// so the forest is left untouched and `left` is returned.
    fn union_repr(&mut self, left: usize, right: usize) -> usize {
        if left == right {
            // Merging a root with itself must not bump its rank. Doing so would inflate ranks on
            // every redundant union and eventually defeat union-by-rank once they saturate.
            return left;
        }
        let (parent, child, bump) = match self.nodes[left].rank().cmp(&self.nodes[right].rank()) {
            Ordering::Greater => (left, right, 0),
            Ordering::Equal => (left, right, 1),
            Ordering::Less => (right, left, 0),
        };
        self.nodes[child].set_parent(parent);
        self.nodes[parent].bump_rank(bump);
        parent
    }
}

// Two forests compare equal when they induce the same equivalence relation, i.e. each one
// refines the other. Tree shape and the number of stored nodes are irrelevant: only which
// indices share a root matters, and unstored indices count as singletons.
impl PartialEq for DisjointSetForest {
    fn eq(&self, other: &Self) -> bool {
        let self_refines_other = self.refines(other);
        self_refines_other && other.refines(self)
    }
}

// Mutual refinement is reflexive, symmetric and transitive, so `Eq` holds.
impl Eq for DisjointSetForest {}

impl UnionFind for DisjointSetForest {
    #[inline]
    fn union_find(&mut self, left: usize, right: usize) -> usize {
        // Materialize every node up to the larger index so both sides are real entries.
        let highest = left.max(right);
        let len = self.nodes.len();
        if highest >= len {
            self.nodes.reserve(1 + highest - len);
            self.nodes.extend((len..=highest).map(Node::new));
        }
        let left_root = self.find_mut(left);
        let right_root = self.find_mut(right);
        self.union_repr(left_root, right_root)
    }

    #[inline]
    fn pre_union_find(&mut self, left: usize, right: usize) -> usize {
        let left_root = self.find_mut(left);
        if left == right {
            return left_root;
        }
        let right_root = self.find_mut(right);
        // Indices not yet stored behave like fresh rank-0 singletons.
        let left_rank = self.nodes.get(left_root).map_or(0, Node::rank);
        let right_rank = self.nodes.get(right_root).map_or(0, Node::rank);
        // Ties favour `left`, matching the tie-break in `union_repr`.
        if left_rank < right_rank {
            right_root
        } else {
            left_root
        }
    }

    #[inline]
    fn union(&mut self, left: usize, right: usize) {
        let _ = self.union_find(left, right);
    }

    #[inline]
    fn find(&self, ix: usize) -> usize {
        // Unstored indices are their own singleton class.
        if ix >= self.nodes.len() {
            return ix;
        }
        let mut root = ix;
        loop {
            let parent = self.nodes[root].parent();
            if parent == root {
                return root;
            }
            root = parent;
        }
    }

    #[inline]
    fn node_eq(&self, i: usize, j: usize) -> bool {
        i == j || self.find(i) == self.find(j)
    }

    #[inline]
    fn find_mut(&mut self, node: usize) -> usize {
        if node >= self.nodes.len() {
            return node;
        }
        // Path splitting: each visited node is re-pointed at its grandparent while walking up.
        let mut current = node;
        let mut parent = self.nodes[current].parent();
        while parent != current {
            let grandparent = self.nodes[parent].parent();
            self.nodes[current].set_parent(grandparent);
            current = parent;
            parent = self.nodes[current].parent();
        }
        current
    }

    #[inline]
    fn node_eq_mut(&mut self, left: usize, right: usize) -> bool {
        left == right || self.find_mut(left) == self.find_mut(right)
    }
}

/// A node in a disjoint set forest, packed into a single `u64`.
///
/// The most significant 8 bits hold the node's rank, which union by rank uses and which saturates
/// at 255. The remaining 56 low bits hold the index of the node's parent. A node whose parent is
/// its own index is a root.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
struct Node(u64);

impl Debug for Node {
    /// Formats as `Node(parent, rank)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_tuple("Node")
            .field(&self.parent())
            .field(&self.rank())
            .finish()
    }
}

impl Node {
    /// Bit offset of the rank field: the rank occupies the top 8 bits.
    const RANK_SHIFT: u32 = u64::BITS - 8;
    /// Mask selecting the low 56 bits, which store the parent index.
    const IX_MASK: u64 = (0b1 << Self::RANK_SHIFT) - 1;
    /// Mask selecting the high 8 bits, which store the rank.
    const RANK_MASK: u64 = !Self::IX_MASK;

    /// Create a rank-0 node with the given parent index.
    ///
    /// # Panics
    /// Panics if `parent` does not fit in the 56-bit index field.
    fn new(parent: usize) -> Node {
        assert_eq!(
            parent as u64 & !Self::IX_MASK,
            0,
            "Overflowed maximum node capacity!"
        );
        Node(parent as u64)
    }

    /// Replace the parent index, keeping the rank bits intact.
    ///
    /// In debug builds, asserts that `ix` fits in the 56-bit index field.
    fn set_parent(&mut self, ix: usize) {
        debug_assert_eq!(ix as u64 & !Self::IX_MASK, 0, "Set invalid parent {}", ix);
        self.0 = (self.0 & !Self::IX_MASK) | (ix as u64)
    }

    /// The index of this node's parent (equal to the node's own index for a root).
    fn parent(&self) -> usize {
        (self.0 & Self::IX_MASK) as usize
    }

    /// The rank stored in the top 8 bits.
    fn rank(&self) -> u8 {
        ((self.0 & Self::RANK_MASK) >> Self::RANK_SHIFT) as u8
    }

    /// Increase the rank by `bump`, saturating at `u8::MAX`, keeping the parent bits intact.
    fn bump_rank(&mut self, bump: u8) {
        self.0 = (self.0 & !Self::RANK_MASK)
            | (self.rank().saturating_add(bump) as u64) << Self::RANK_SHIFT;
    }
}

/// Hash a single value using a fresh hasher obtained from `hasher`, returning the final digest.
#[inline(always)]
fn hash_one(hasher: &impl BuildHasher, hashee: &impl Hash) -> u64 {
    let mut state = hasher.build_hasher();
    Hash::hash(hashee, &mut state);
    state.finish()
}

#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::SliceRandom;
    use rand::Rng;

    /// Congruence must propagate through a chain of two levels of applications.
    #[test]
    fn congruence_chain() {
        let mut dsf = DisjointSetForest::with_capacity(10);
        let mut cc = CongruenceClosure::new();
        let mut cs = CongruenceState::new();

        let (a, b, c, d, e, f, g, h) = (0, 1, 2, 3, 4, 5, 6, 7);

        // Set a * b ~ c
        cc.equation((a, b), c, &mut dsf, &mut cs);
        // Set a * d ~ e
        cc.equation((a, d), e, &mut dsf, &mut cs);
        // Set f * c ~ g
        cc.equation((f, c), g, &mut dsf, &mut cs);
        // Set f * e ~ h
        cc.equation((f, e), h, &mut dsf, &mut cs);

        assert!(dsf.is_empty());

        // If b ~ d, then a * b ~ c ~ a * d ~ e, and therefore f * c ~ g ~ f * e ~ h
        cc.merge(b, d, &mut dsf, &mut cs);
        assert!(!dsf.is_empty());
        assert!(dsf.node_eq_mut(b, d));
        assert!(dsf.node_eq_mut(c, e));
        assert!(dsf.node_eq_mut(g, h));
        assert!(!dsf.node_eq_mut(a, b));
        assert!(!dsf.node_eq_mut(a, c));
        assert!(!dsf.node_eq_mut(a, g));
        assert!(!dsf.node_eq_mut(b, c));
        assert!(!dsf.node_eq_mut(b, g));
        assert!(!dsf.node_eq_mut(c, g));
    }

    /// The closure's internal invariants must hold after every operation.
    ///
    /// Uses `assert!` rather than `debug_assert!` so the checks also run under
    /// `cargo test --release`.
    #[test]
    fn invariants() {
        let mut dsf = DisjointSetForest::with_capacity(3);
        let mut cc = CongruenceClosure::new();
        let mut st = CongruenceState::default();

        cc.merge(0, 1, &mut dsf, &mut st);
        assert!(cc.check_invariants(&mut dsf));
        cc.merge(0, 2, &mut dsf, &mut st);
        assert!(cc.check_invariants(&mut dsf));
        cc.equation((0, 1), 0, &mut dsf, &mut st);
        assert!(cc.check_invariants(&mut dsf));
        cc.equation((2, 3), 4, &mut dsf, &mut st);
        assert!(cc.check_invariants(&mut dsf));
        cc.equation((0, 1), 0, &mut dsf, &mut st);
        assert!(cc.check_invariants(&mut dsf));
        cc.merge(3, 5, &mut dsf, &mut st);
        assert!(cc.check_invariants(&mut dsf));
        cc.merge(5, 0, &mut dsf, &mut st);
        assert!(cc.check_invariants(&mut dsf));
        cc.equation((3, 5), 5, &mut dsf, &mut st);
        assert!(cc.check_invariants(&mut dsf));
    }

    /// The final closure must not depend on the order in which merges and equations arrive.
    ///
    /// The strictness assertions (`!refines_mut`, `assert_ne!`) rely on the random equations
    /// producing at least one congruence-derived equality. With this many equations over this
    /// few nodes, that is overwhelmingly likely but not guaranteed.
    #[test]
    fn congruence_closure_order_and_join() {
        let mut rng = rand::thread_rng();
        let mut dsf = DisjointSetForest::with_capacity(3);
        let mut cc_dsf = DisjointSetForest::with_capacity(30);
        let mut cc_end_dsf = DisjointSetForest::with_capacity(64);
        let mut cc = CongruenceClosure::new();
        let mut cc_end = CongruenceClosure::new();
        let mut st = CongruenceState::default();
        let mut equations = Vec::with_capacity(4096);

        for _ in 0..2048 {
            let a = rng.gen::<usize>() % 256;
            let b = rng.gen::<usize>() % 256;
            dsf.union(a, b);
            cc_end.merge(a, b, &mut cc_end_dsf, &mut st);
            cc.merge(a, b, &mut cc_dsf, &mut st);
            let c = rng.gen::<usize>() % 256;
            let d = rng.gen::<usize>() % 256;
            dsf.union(c, d);
            cc_end.merge(c, d, &mut cc_end_dsf, &mut st);
            cc.merge(c, d, &mut cc_dsf, &mut st);
            let e = rng.gen::<usize>() % 256;
            let f = rng.gen::<usize>() % 256;
            cc.equation((a, b), c, &mut cc_dsf, &mut st);
            cc.equation((d, e), f, &mut cc_dsf, &mut st);
            equations.push((a, b, c, d, e, f));
            let e = e + 256;
            let f = f + 256;
            cc.equation((a, b), c, &mut cc_dsf, &mut st);
            cc.equation((d, e), f, &mut cc_dsf, &mut st);
            equations.push((a, b, c, d, e, f));
        }
        assert!(dsf.refines_mut(&mut cc_dsf));
        assert!(!cc_dsf.refines_mut(&mut dsf));
        assert_eq!(dsf, cc_end_dsf);

        equations.shuffle(&mut rng);
        for (a, b, c, d, e, f) in equations {
            cc_end.equation((a, b), c, &mut cc_end_dsf, &mut st);
            cc_end.equation((d, e), f, &mut cc_end_dsf, &mut st);
        }
        assert_ne!(dsf, cc_end_dsf);
        assert_eq!(cc_dsf, cc_end_dsf);
    }
}
