//! Contains [`DisjointSets`](crate::disjoint_sets::DisjointSets), the main implementation
//! of a disjoint-set forest supporting the union-find algorithm.
//! [`DisjointSets`](crate::disjoint_sets::DisjointSets) uses union-by-rank and path-compression
//! to improve the algorithmic complexity.
//!
//! See [`DisjointSets`](crate::disjoint_sets::DisjointSets) for more details.

use crate::{
    node::Node,
    traits::{Error, Result, UnionFind},
};
use std::{
    cmp::Ordering::*,
    collections::{HashMap, HashSet},
    hash::Hash,
};

/// A disjoint-set forest supporting the union-find operations.
///
/// Two heuristics keep the trees shallow: _union-by-rank_ (when two sets are merged, the
/// root of lower rank is attached beneath the root of higher rank) and _path-compression_
/// (every node visited while searching for a representative is re-pointed directly at
/// its root).
///
/// The operations themselves are provided through the
/// [`UnionFind`](crate::traits::UnionFind) trait, which this struct implements.
///
/// # Algorithmic Complexity
///
/// Thanks to the two heuristics above, any sequence of `m` calls to `make_set`, `union`
/// and `find_set` runs in `O(mα(n))` time, where `α(n)` is the
/// [inverse Ackermann function](https://en.wikipedia.org/wiki/Inverse_Ackermann_function)
/// (see the [Wikipedia article](https://en.wikipedia.org/wiki/Disjoint-set_data_structure)).
/// Since `α(n)` is tiny for any realistic `n`, each operation is effectively `O(1)`.
///
/// # References
///
/// The implementation draws on:
///
/// 1. [Introduction to Algorithms](https://en.wikipedia.org/wiki/Introduction_to_Algorithms)
/// 1. [Disjoint-set data structure](https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
pub struct DisjointSets<T: Copy + Eq + Hash> {
    // every stored item, keyed to the tree node that records its parent and rank
    nodes: HashMap<T, Node<T>>,
}

//
impl<T> UnionFind<T> for DisjointSets<T>
where
    T: Copy + Eq + Hash,
{
    /// See [`UnionFind`](crate::traits::UnionFind) for an example.
    fn make_set(&mut self, item: T) -> Result<()> {
        //
        match self.contains(item) {
            //
            true => Err(Error::ItemAlreadyExists),
            //
            false => {
                //
                let node = Node::new(item, 0);
                self.nodes.insert(item, node);
                //
                Ok(())
            }
        }
    }

    /// See [`UnionFind`](crate::traits::UnionFind) for an example.
    fn union(&self, x: T, y: T) -> Result<()> {
        //
        match x == y {
            //
            true => Ok(()),
            //
            false => self.link(
                self.find_set(x)?, //
                self.find_set(y)?, //
            ),
        }
    }

    /// See [`UnionFind`](crate::traits::UnionFind) for an example.
    fn find_set(&self, item: T) -> Result<T> {
        //
        match self.get_node(&item) {
            //
            Some(node) => {
                //
                debug_assert!(item == node.get_item());

                //
                match node.get_parent() {
                    // `node` is not a root node, so
                    // 1. do path compression
                    // 2. set parent to the root node
                    Some(parent) => {
                        let root = self.find_set(parent)?;
                        node.set_parent(root);
                        Ok(root)
                    }
                    // `node` is a root node, so return its item directly
                    None => Ok(item),
                }
            }
            // `item` does not exist in this set
            None => Err(Error::ItemDoesNotExist),
        }
    }
}

//
impl<T> DisjointSets<T>
where
    T: Copy + Eq + Hash,
{
    /// Creates an empty instance of `DisjointSets`.
    ///
    /// # Example
    ///
    /// ```
    /// use union_find_rs::prelude::*;
    ///
    /// let mut sets: DisjointSets<usize> = DisjointSets::new();
    /// ```
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
        }
    }

    /// Creates an empty instance of `DisjointSets` with the specified capacity.
    ///
    /// The instance returned by this method will be able to hold at least `capacity` elements without reallocating.
    /// If capacity is 0, the hash map will not allocate.
    ///
    /// # Example
    ///
    /// ```
    /// use union_find_rs::prelude::*;
    ///
    /// let mut sets: DisjointSets<usize> = DisjointSets::with_capacity(64);
    /// ```
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            nodes: HashMap::with_capacity(capacity),
        }
    }

    /// Returns the number of elements in this instance of `DisjointSets`.
    ///
    /// # Example
    ///
    /// ```
    /// use union_find_rs::prelude::*;
    ///
    /// let mut sets: DisjointSets<usize> = DisjointSets::new();
    ///
    /// sets.make_set(9).unwrap();
    ///
    /// assert_eq!(sets.len(), 1);
    /// ```
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if this instance of `DisjointSets` contains `item`, `false` otherwise.
    ///
    /// # Example
    ///
    /// ```
    /// use union_find_rs::prelude::*;
    ///
    /// let mut sets: DisjointSets<usize> = DisjointSets::new();
    ///
    /// sets.make_set(9).unwrap();
    ///
    /// assert_eq!(sets.contains(9), true);
    /// assert_eq!(sets.contains(7), false);
    /// ```
    pub fn contains(&self, item: T) -> bool {
        self.nodes.contains_key(&item)
    }

    /// Clears this `DisjointSets`, removing all data. Keeps the allocated memory for reuse.
    ///
    /// # Example
    ///
    /// ```
    /// use union_find_rs::prelude::*;
    ///
    /// let mut sets: DisjointSets<usize> = DisjointSets::new();
    ///
    /// sets.make_set(9).unwrap();
    /// assert_eq!(sets.len(), 1);
    ///
    /// sets.clear();
    /// assert_eq!(sets.len(), 0);
    /// ```
    pub fn clear(&mut self) {
        self.nodes.clear()
    }

    /// Reserves capacity for at least additional more elements to be inserted in this `DisjointSets`.
    /// May reserve more space to avoid frequent reallocations.
    ///
    /// # Panics
    ///
    /// Panics if the new allocation size overflows usize.
    ///
    /// # Example
    ///
    /// ```
    /// use union_find_rs::prelude::*;
    ///
    /// let mut sets: DisjointSets<usize> = DisjointSets::new();
    ///
    /// sets.reserve(16);
    /// ```
    pub fn reserve(&mut self, additional: usize) {
        self.nodes.reserve(additional)
    }

    /// Shrinks the capacity of this `DisjointSets` as much as possible.
    /// It will drop down as much as possible while maintaining the internal
    /// rules and possibly leaving some space in accordance with the resize policy.
    ///
    /// # Example
    ///
    /// ```
    /// use union_find_rs::prelude::*;
    ///
    /// let mut sets: DisjointSets<usize> = DisjointSets::with_capacity(100);
    ///
    /// sets.make_set(4).unwrap();
    /// sets.make_set(9).unwrap();
    /// assert!(100 <= sets.capacity());
    ///
    /// sets.shrink_to_fit();
    /// assert!(2 <= sets.capacity());
    /// ```
    pub fn shrink_to_fit(&mut self) {
        self.nodes.shrink_to_fit()
    }

    /// Returns the number of elements the map can hold without reallocating.
    /// This number is a lower bound; the HashMap<K, V> might be able to hold more,
    /// but is guaranteed to be able to hold at least this many.
    ///
    /// # Example
    ///
    /// ```
    /// use union_find_rs::prelude::*;
    ///
    /// let mut sets: DisjointSets<usize> = DisjointSets::with_capacity(100);
    ///
    /// assert!(100 <= sets.capacity());
    /// ```
    pub fn capacity(&self) -> usize {
        self.nodes.capacity()
    }

    //
    fn get_node(&self, item: &T) -> Option<&Node<T>> {
        self.nodes.get(item)
    }

    // link the two sets for which `x` and `y` are representatives
    fn link(&self, x: T, y: T) -> Result<()> {
        // assert that `x` and `y` are indeed root nodes
        debug_assert!(self.get_node(&x).unwrap().get_parent().is_none());
        debug_assert!(self.get_node(&y).unwrap().get_parent().is_none());
        //
        match self.nodes.contains_key(&x) && self.nodes.contains_key(&y) {
            //
            true => {
                //
                let x_node = self.get_node(&x).unwrap();
                let y_node = self.get_node(&y).unwrap();
                //
                let x_rank = x_node.get_rank();
                let y_rank = y_node.get_rank();
                //
                match x_rank.cmp(&y_rank) {
                    // node with the larger rank becomes the parent
                    Greater => y_node.set_parent(x),
                    Less => x_node.set_parent(y),
                    // choose the parent arbitrarily, and increment the parent's rank
                    Equal => {
                        x_node.set_parent(y);
                        y_node.set_rank(y_rank + 1);
                    }
                }
                //
                Ok(())
            }
            //
            false => Err(Error::ItemDoesNotExist),
        }
    }
}

// Consuming iteration: yields one `HashSet` per disjoint set.
impl<T> IntoIterator for DisjointSets<T>
where
    T: Copy + Eq + Hash,
{
    type Item = HashSet<T>;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    /// Consumes this `DisjointSets` and returns an iterator over its disjoint sets. Each set
    /// is collected into a `HashSet` of its members.
    ///
    /// The order in which sets are yielded, and the order of items inside each set, follow
    /// `HashMap`/`HashSet` iteration order and are therefore unspecified.
    fn into_iter(self) -> Self::IntoIter {
        // group every item under the representative of its set
        let mut rep_to_items: HashMap<T, HashSet<T>> = HashMap::with_capacity(self.len());
        for &item in self.nodes.keys() {
            let rep = self
                .find_set(item)
                .expect("every key of `nodes` is a member of this DisjointSets");
            // entry API: a single hash lookup per item instead of three
            rep_to_items.entry(rep).or_default().insert(item);
        }

        // the representatives are no longer needed; keep only the member sets
        rep_to_items
            .into_iter()
            .map(|(_, members)| members)
            .collect::<Vec<_>>()
            .into_iter()
    }
}

// Unit tests. Every `*_tester` function is generic over the item type and is run against
// several concrete types through the `test_with_different_types!` macro at the bottom.
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use std::{
        collections::{hash_map::DefaultHasher, HashSet},
        hash::{Hash, Hasher},
    };

    // Generates a module `$mod_name` with one `#[test]` per item type: `usize`, `i64`, a
    // `Copy` struct, and references to a non-`Copy` struct. Each test passes a `Vec` of
    // distinct items of that type to the generic tester `$test_fn_name`. No extra
    // `#[cfg(test)]` is needed because this macro is only expanded inside `mod tests`.
    macro_rules! test_with_different_types {
        ( $mod_name:ident, $test_fn_name:ident ) => {
            mod $mod_name {
                use super::*;

                // items per test (the `i64` test uses twice as many, spanning negative
                // and non-negative values)
                const ITEM_COUNT: usize = 4096;

                #[test]
                fn test_with_usize() {
                    let items: Vec<usize> = (0..ITEM_COUNT).collect();
                    $test_fn_name(items);
                }

                #[test]
                fn test_with_i64() {
                    const ITEM_COUNT_I64: i64 = ITEM_COUNT as i64;
                    let items: Vec<i64> = (-ITEM_COUNT_I64..ITEM_COUNT_I64).collect();
                    $test_fn_name(items);
                }

                #[test]
                fn test_with_copyable_structs() {
                    #[derive(Copy, Clone, PartialEq, Eq, Hash)]
                    struct CustomStruct {
                        id: usize,
                    }

                    let items: Vec<CustomStruct> =
                        (0..ITEM_COUNT).map(|i| CustomStruct { id: i }).collect();
                    $test_fn_name(items);
                }

                // non-`Copy` values are exercised through shared references, which are `Copy`
                #[test]
                fn test_noncopyable() {
                    #[derive(PartialEq, Eq, Hash)]
                    struct CustomStruct {
                        id: usize,
                    }

                    let items: Vec<CustomStruct> =
                        (0..ITEM_COUNT).map(|i| CustomStruct { id: i }).collect();
                    let refs: Vec<&CustomStruct> = items.iter().collect();
                    $test_fn_name(refs);
                }
            }
        };
    }

    // Chains every item to its predecessor, which builds one long set through repeated
    // unions, and checks that `find_set` (with path compression) agrees on a single
    // representative.
    fn path_compression_tester<T>(items: Vec<T>)
    where
        T: Copy + Eq + Hash,
    {
        let mut union_find = DisjointSets::new();
        items
            .iter()
            .copied()
            .for_each(|item| union_find.make_set(item).unwrap());

        // union each item with the one before it
        items
            .windows(2)
            .for_each(|pair| union_find.union(pair[1], pair[0]).unwrap());

        // check that all items have the same representative
        let reps: HashSet<_> = items
            .iter()
            .copied()
            .map(|item| union_find.find_set(item).unwrap())
            .collect();
        assert_eq!(reps.len(), 1);
    }

    // Without any unions, every item must be its own singleton set.
    fn make_set_only_tester<T>(items: Vec<T>)
    where
        T: Copy + Eq + Hash,
    {
        let mut union_find = DisjointSets::new();
        items
            .iter()
            .copied()
            .for_each(|item| union_find.make_set(item).unwrap());

        // check that every item has a distinct representative
        let reps: HashSet<_> = items
            .iter()
            .copied()
            .map(|item| union_find.find_set(item).unwrap())
            .collect();
        assert_eq!(reps.len(), items.len());
    }

    // Unions every item with its predecessor and checks, via `IntoIterator`, that exactly
    // one set remains.
    fn union_all_tester<T>(items: Vec<T>)
    where
        T: Copy + Eq + Hash,
    {
        let mut union_find = DisjointSets::new();
        items
            .iter()
            .copied()
            .for_each(|item| union_find.make_set(item).unwrap());

        // union each item with the one before it
        items
            .windows(2)
            .for_each(|pair| union_find.union(pair[1], pair[0]).unwrap());

        // check that there is only 1 set
        let sets: Vec<HashSet<T>> = union_find.into_iter().collect();
        assert_eq!(sets.len(), 1);
    }

    // Hashes `t` with the standard library's `DefaultHasher`; the tests below only
    // use the result's parity to split items into two groups.
    fn hash<T>(t: &T) -> u64
    where
        T: Hash,
    {
        let mut s = DefaultHasher::new();
        t.hash(&mut s);
        s.finish()
    }

    // Partitions items by the parity of their hash and unions each group into one set.
    // Relies on both parities occurring, which is practically certain for thousands of items.
    fn union_by_hash_parity_tester<T>(items: Vec<T>)
    where
        T: Copy + Eq + Hash,
    {
        let mut union_find = DisjointSets::new();
        items
            .iter()
            .copied()
            .for_each(|item| union_find.make_set(item).unwrap());

        // the first item seen with each parity becomes the anchor for later unions
        let mut first_even = None;
        let mut first_odd = None;

        for &item in &items {
            let anchor = if hash(&item) % 2 == 0 {
                &mut first_even
            } else {
                &mut first_odd
            };
            match *anchor {
                Some(x) => union_find.union(item, x).unwrap(),
                None => *anchor = Some(item),
            }
        }

        // expect exactly two sets: one of even-hash items and one of odd-hash items
        let sets: Vec<HashSet<T>> = union_find.into_iter().collect();
        assert_eq!(sets.len(), 2);
        assert!(sets.iter().any(|set| set.iter().all(|x| hash(x) % 2 == 0)));
        assert!(sets.iter().any(|set| set.iter().all(|x| hash(x) % 2 == 1)));
    }

    test_with_different_types!(make_set_only, make_set_only_tester);
    test_with_different_types!(union_by_hash_parity, union_by_hash_parity_tester);
    test_with_different_types!(path_compression, path_compression_tester);
    test_with_different_types!(union_all, union_all_tester);
}
