use std::marker::PhantomData;
use crate::{Vertex, Graph};
use crate::algorithms::{VisitorDFSAction, VisitorDFS, dfs_with_visitor};
use crate::error::{GraphError, ErrorKind};

/// Result of `find_subtrees_size`: the size of every subtree discovered by the
/// DFS, stored by vertex id.
///
/// The lifetime `'a` is the lifetime of the graph the sizes were computed from;
/// no vertex data is actually stored, only the sizes.
#[derive(Debug, Clone)]
pub struct SubTreesSize<'a, T> {
    /// `values[id]` is the size of the subtree rooted at vertex `id`, or `None`
    /// if that vertex was not reached by the traversal.
    values: Vec<Option<usize>>,
    /// Carries the graph lifetime/type without owning any `T`.
    phantom: PhantomData<&'a T>,
}

impl<'a, T> SubTreesSize<'a, T> where T: Copy {
    /// Returns the number of vertices in the subtree rooted at `target`,
    /// counting `target` itself.
    ///
    /// Returns `None` when `target` was not reached by the traversal that built
    /// this table, or when its id lies outside the table (for example a vertex
    /// taken from a different, larger graph) instead of panicking on an
    /// out-of-bounds index.
    pub fn get_subtree_size(&self, target: &Vertex<T>) -> Option<usize> {
        self.values.get(target.id()).copied().flatten()
    }
}

/// DFS visitor that accumulates subtree sizes and remembers whether the
/// traversal ran into a cycle.
struct CustomVisitor {
    /// Set once a cycle has been detected; the traversal is then asked to stop.
    cycle: bool,
    /// Subtree size per vertex id; `None` until the vertex is entered.
    values: Vec<Option<usize>>,
}

impl<T> VisitorDFS<T> for CustomVisitor {
    /// A freshly entered vertex starts as a subtree containing only itself.
    fn entry_to_vertex_event(&mut self, vertex: &Vertex<T>) -> Result<VisitorDFSAction, GraphError> {
        self.values[vertex.id] = Some(1);
        Ok(VisitorDFSAction::Nothing)
    }

    /// Called when the DFS leaves `vertex`, which was discovered from `parent`:
    /// the finished child subtree is added to the parent's running total.
    fn exit_from_white_vertex_event(&mut self, vertex: &Vertex<T>, parent: &Vertex<T>, _grand_parent: Option<&Vertex<T>>) -> Result<VisitorDFSAction, GraphError> {
        // Both vertices were entered before this exit event, so
        // `entry_to_vertex_event` has already initialised their sizes.
        let child_size = self.values[vertex.id]
            .expect("child subtree size is initialised when the child is entered");
        let parent_size = self.values[parent.id]
            .expect("parent subtree size is initialised when the parent is entered");
        self.values[parent.id] = Some(parent_size + child_size);
        Ok(VisitorDFSAction::Nothing)
    }

    /// Called when the DFS, standing at `_parent`, meets an already-entered
    /// (grey) `vertex`.
    ///
    /// Meeting `grand_parent` (the vertex `_parent` was reached from) again is
    /// just the edge we arrived by. Any other grey vertex closes a cycle, so the
    /// cycle flag is set and the traversal is asked to stop. When `_parent` is
    /// the DFS root there is no `grand_parent`, so any grey vertex (e.g. a
    /// self-loop on the root) is treated as a cycle rather than panicking.
    fn entry_to_grey_vertex_event(&mut self, vertex: &Vertex<T>, _parent: &Vertex<T>, grand_parent: Option<&Vertex<T>>) -> Result<VisitorDFSAction, GraphError> {
        let returns_to_grand_parent = grand_parent.map_or(false, |gp| gp.id == vertex.id);
        if !self.cycle && !returns_to_grand_parent {
            self.cycle = true;
            return Ok(VisitorDFSAction::Break);
        }
        Ok(VisitorDFSAction::Nothing)
    }
}

/// Computes the size of every subtree of the tree rooted at `from`.
///
/// The size of a vertex's subtree is the number of vertices in it, the vertex
/// itself included, so leaves have size 1. The part of the graph reachable from
/// `from` is expected to be a tree. Connectivity is not checked: vertices that
/// are not reachable from `from` simply report `None`.
///
/// Complexity: one depth-first search from `from` plus an O(|V|) result table,
/// i.e. O(|V| + |E|) for a standard adjacency-list DFS.
///
/// # Errors
///
/// Returns `GraphError::Regular(ErrorKind::TreeContainsCycle)` if the search
/// from `from` runs into a cycle. Errors raised by the underlying DFS are
/// propagated unchanged.
///
/// # Examples
///
/// ```
/// use luka::{Graph, algorithms};
///
/// let mut graph = Graph::new(12);
///
/// graph.add_edge(1, 4, 0).unwrap();
/// graph.add_edge(1, 2, 0).unwrap();
/// graph.add_edge(4, 11, 0).unwrap();
/// graph.add_edge(4, 12, 0).unwrap();
/// graph.add_edge(12, 3, 0).unwrap();
/// graph.add_edge(2, 5, 0).unwrap();
/// graph.add_edge(2, 6, 0).unwrap();
/// graph.add_edge(5, 9, 0).unwrap();
/// graph.add_edge(5, 10, 0).unwrap();
/// graph.add_edge(6, 7, 0).unwrap();
/// graph.add_edge(7, 8, 0).unwrap();
///
/// let subtrees_size = algorithms::find_subtrees_size(&graph, graph.get_vertex(1).unwrap()).unwrap();
/// assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(1).unwrap()), Some(12));
/// assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(2).unwrap()), Some(7));
/// assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(4).unwrap()), Some(4));
/// ```
pub fn find_subtrees_size<'a, T>(graph: &'a Graph<T>, from: &'a Vertex<T>) -> Result<SubTreesSize<'a, T>, GraphError> where T: Default + Copy {
    let mut sizes_visitor = CustomVisitor {
        cycle: false,
        values: vec![None; graph.size()],
    };
    dfs_with_visitor(graph, from, &mut sizes_visitor)?;

    let CustomVisitor { values, cycle } = sizes_visitor;
    if cycle {
        Err(GraphError::Regular(ErrorKind::TreeContainsCycle))
    } else {
        Ok(SubTreesSize { values, phantom: PhantomData })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Sizes of the root, inner vertices and leaves of a 12-vertex tree rooted at 1.
    #[test]
    fn test_find_subtrees_size() {
        let mut graph = Graph::new(12);

        graph.add_edge(1, 4, 0).unwrap();
        graph.add_edge(1, 2, 0).unwrap();
        graph.add_edge(4, 11, 0).unwrap();
        graph.add_edge(4, 12, 0).unwrap();
        graph.add_edge(12, 3, 0).unwrap();
        graph.add_edge(2, 5, 0).unwrap();
        graph.add_edge(2, 6, 0).unwrap();
        graph.add_edge(5, 9, 0).unwrap();
        graph.add_edge(5, 10, 0).unwrap();
        graph.add_edge(6, 7, 0).unwrap();
        graph.add_edge(7, 8, 0).unwrap();

        let subtrees_size = find_subtrees_size(&graph, graph.get_vertex(1).unwrap()).unwrap();
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(1).unwrap()), Some(12));
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(2).unwrap()), Some(7));
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(4).unwrap()), Some(4));
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(5).unwrap()), Some(3));
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(6).unwrap()), Some(3));
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(12).unwrap()), Some(2));
        // Leaves are subtrees of size one.
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(3).unwrap()), Some(1));
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(8).unwrap()), Some(1));
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(11).unwrap()), Some(1));
    }

    /// Vertices not reachable from the start vertex have no subtree size.
    #[test]
    fn test_unreachable_vertex_has_no_size() {
        let mut graph = Graph::new(5);

        graph.add_edge(1, 2, 0).unwrap();
        graph.add_edge(2, 3, 0).unwrap();

        let subtrees_size = find_subtrees_size(&graph, graph.get_vertex(1).unwrap()).unwrap();
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(1).unwrap()), Some(3));
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(3).unwrap()), Some(1));
        assert_eq!(subtrees_size.get_subtree_size(graph.get_vertex(5).unwrap()), None);
    }

    /// A cycle reachable from the start vertex is reported as an error.
    #[test]
    fn test_cycle_is_rejected() {
        let mut graph = Graph::new(3);

        graph.add_edge(1, 2, 0).unwrap();
        graph.add_edge(2, 3, 0).unwrap();
        graph.add_edge(3, 1, 0).unwrap();

        assert!(find_subtrees_size(&graph, graph.get_vertex(1).unwrap()).is_err());
    }
}