use crate::{Graph, Vertex};
use crate::error::GraphError;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::dsu::DSU;

/// One edge of the spanning tree (or spanning forest) produced by `kruskal`.
///
/// `from` and `to` borrow the endpoint vertices from the graph that was searched, so a
/// result cannot outlive that graph. `weight` holds the edge's weight by value.
///
/// The struct is `Clone`/`Copy` whenever `T` is: both endpoints are shared references,
/// so copying an edge never copies a vertex.
#[derive(Clone, Copy)]
pub struct EdgeSpanningTreeEdge<'a, T> {
    /// Vertex the edge starts at, in the orientation the edge was stored in the graph.
    pub from: &'a Vertex<T>,
    /// Vertex the edge ends at.
    pub to: &'a Vertex<T>,
    /// Weight of the edge.
    pub weight: T,
}

/// Heap entry used internally by `kruskal`: a directed edge `from -> to` of weight `dist`.
///
/// Equality and ordering look at `dist` only. The ordering is deliberately *reversed*
/// (a smaller `dist` compares as greater), so that `std::collections::BinaryHeap`, which
/// is a max-heap, pops the cheapest edge first.
struct KruskalEdge<U> {
    from: usize,
    to: usize,
    dist: U,
}

impl<U> PartialEq for KruskalEdge<U>
where
    U: PartialOrd,
{
    fn eq(&self, other: &Self) -> bool {
        self.dist == other.dist
    }
}

impl<U> Eq for KruskalEdge<U> where U: PartialOrd {}

impl<U> Ord for KruskalEdge<U>
where
    U: PartialOrd,
{
    /// Compares two edges by weight, cheapest edge being the *greatest*.
    ///
    /// # Panics
    ///
    /// Panics if the two weights are not comparable (e.g. either one is `f64::NAN`),
    /// since no total order exists for such values.
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped on purpose: this turns BinaryHeap's max-heap into a min-heap.
        other
            .dist
            .partial_cmp(&self.dist)
            .expect("edge weights must be totally ordered (NaN weights are not supported)")
    }
}

impl<U> PartialOrd for KruskalEdge<U>
where
    U: PartialOrd,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: delegate to `Ord` so the two impls can never disagree.
        Some(self.cmp(other))
    }
}

/// Computes a minimum spanning tree with Kruskal's algorithm.
///
/// Kruskal's algorithm efficiently constructs the minimum spanning tree of a weighted,
/// connected, undirected graph. Each undirected edge is expected to be stored in both
/// directions (as in the example below). Both copies are considered, but at most one of
/// them can be accepted. If the graph is disconnected, the result is a minimum spanning
/// forest: one tree per connected component.
///
/// How it works:
/// - Every edge in the adjacency lists is pushed onto a min-heap ordered by weight.
/// - Edges are popped cheapest-first.
/// - An edge is kept only when its endpoints lie in different disjoint sets, which are
///   then merged.
///
/// The returned edges are therefore listed in non-decreasing order of weight. Each keeps
/// the orientation (`from`, `to`) in which it was stored in the graph.
///
/// Algorithmic complexity - O(|E| * log(|E|)), where |E| is the number of edges in the graph.
///
/// # Errors
///
/// The current implementation never returns `Err`. Failures of the internal disjoint-set
/// operations or vertex lookups panic instead (see below).
///
/// # Panics
///
/// Panics if two edge weights cannot be compared (for example an `f64` weight that is
/// `NaN`), or if a disjoint-set operation or vertex lookup reports a failure.
///
/// # Examples
///
/// ```
/// use luka::{Graph, algorithms};
///
/// let mut graph = Graph::new(20);
///
/// graph.add_edge(1, 2, 7.0).unwrap();
/// graph.add_edge(2, 1, 7.0).unwrap();
/// graph.add_edge(1, 4, 5.0).unwrap();
/// graph.add_edge(4, 1, 5.0).unwrap();
/// graph.add_edge(2, 3, 8.0).unwrap();
/// graph.add_edge(3, 2, 8.0).unwrap();
/// graph.add_edge(2, 5, 7.0).unwrap();
/// graph.add_edge(5, 2, 7.0).unwrap();
/// graph.add_edge(2, 4, 9.0).unwrap();
/// graph.add_edge(4, 2, 9.0).unwrap();
/// graph.add_edge(3, 5, 5.0).unwrap();
/// graph.add_edge(5, 3, 5.0).unwrap();
/// graph.add_edge(5, 7, 9.0).unwrap();
/// graph.add_edge(7, 5, 9.0).unwrap();
/// graph.add_edge(5, 6, 8.0).unwrap();
/// graph.add_edge(6, 5, 8.0).unwrap();
/// graph.add_edge(5, 4, 15.0).unwrap();
/// graph.add_edge(4, 5, 15.0).unwrap();
/// graph.add_edge(6, 7, 11.0).unwrap();
/// graph.add_edge(7, 6, 11.0).unwrap();
/// graph.add_edge(6, 4, 6.0).unwrap();
/// graph.add_edge(4, 6, 6.0).unwrap();
///
/// let edges = algorithms::kruskal(&graph).unwrap();
///
/// let summary_weight: f64 = edges.iter().map(|value| value.weight).sum();
/// assert_eq!(39.0, summary_weight);
/// let res = edges.iter().map(|value| (value.from.id(), value.to.id())).collect::<Vec<(usize, usize)>>();
/// assert_eq!(res, vec![(1, 4), (5, 3), (4, 6), (2, 1), (5, 2), (7, 5)]);
/// ```
pub fn kruskal<T>(graph: &Graph<T>) -> Result<Vec<EdgeSpanningTreeEdge<T>>, GraphError>
where
    T: PartialOrd + Copy,
{
    let mut components = DSU::new(graph.size());
    let mut candidates = BinaryHeap::new();

    // Index 0 of the adjacency list is skipped; every other slot gets its own set
    // and contributes all of its outgoing edges as candidates.
    for (vertex_id, vertex) in graph.adj.iter().enumerate().skip(1) {
        components.make_set(vertex_id).unwrap();
        for edge in vertex.edges.iter() {
            candidates.push(KruskalEdge {
                from: vertex_id,
                to: edge.to,
                dist: edge.weight,
            });
        }
    }

    let mut tree = Vec::new();
    while let Some(candidate) = candidates.pop() {
        let (u, v) = (candidate.from, candidate.to);
        // Endpoints already connected: accepting this edge would close a cycle.
        if components.find_set(u) == components.find_set(v) {
            continue;
        }
        components.union_sets(u, v).unwrap();
        tree.push(EdgeSpanningTreeEdge {
            from: graph.get_vertex(u).unwrap(),
            to: graph.get_vertex(v).unwrap(),
            weight: candidate.dist,
        });
    }

    Ok(tree)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kruskal() {
        // Each undirected edge (u, v, weight) is inserted as u -> v and then v -> u,
        // in exactly this order. The expected edge list below depends on how the heap
        // breaks ties between equal weights, and that depends on insertion order.
        let undirected = [
            (1, 2, 7.0),
            (1, 4, 5.0),
            (2, 3, 8.0),
            (2, 5, 7.0),
            (2, 4, 9.0),
            (3, 5, 5.0),
            (5, 7, 9.0),
            (5, 6, 8.0),
            (5, 4, 15.0),
            (6, 7, 11.0),
            (6, 4, 6.0),
        ];

        let mut graph = Graph::new(20);
        for &(u, v, weight) in undirected.iter() {
            graph.add_edge(u, v, weight).unwrap();
            graph.add_edge(v, u, weight).unwrap();
        }

        let tree = kruskal(&graph).unwrap();

        let total: f64 = tree.iter().map(|e| e.weight).sum();
        assert_eq!(39.0, total);

        let endpoints: Vec<(usize, usize)> =
            tree.iter().map(|e| (e.from.id(), e.to.id())).collect();
        assert_eq!(endpoints, vec![(1, 4), (5, 3), (4, 6), (2, 1), (5, 2), (7, 5)]);
    }
}