/* Copyright (C) 2020 Dylan Staatz - All Rights Reserved. */

pub use petgraph::stable_graph::{EdgeIndex, NodeIndex};

////////////////////////////////////////////////////////////////////////////////

use std::collections::HashSet;
use std::marker::PhantomData;
use std::ops::{Index, IndexMut};

use nalgebra::SVector;
use petgraph::stable_graph::{
  DefaultIx, EdgeIndices, EdgeReference, Neighbors, NodeIndices, StableDiGraph,
  WalkNeighbors,
};
use petgraph::visit::EdgeRef;
use petgraph::Direction;
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use crate::scalar::Scalar;
use crate::trajectories::{FullTrajRefOwned, Trajectory};

/// Iterator over the index of every node currently stored in the tree's graph
/// (orphaned nodes included, since orphans stay in the graph until removed).
pub struct NodeIter<'a, X, const N: usize> {
  nodes: NodeIndices<'a, SVector<X, N>>,
}

impl<'a, X, const N: usize> NodeIter<'a, X, N> {
  /// Wraps the node-index iterator of `graph`.
  fn new<T>(graph: &'a StableDiGraph<SVector<X, N>, T>) -> Self
  where
    T: Trajectory<X, N>,
  {
    Self {
      nodes: graph.node_indices(),
    }
  }
}

impl<'a, X, const N: usize> Iterator for NodeIter<'a, X, N> {
  type Item = NodeIndex;

  fn next(&mut self) -> Option<Self::Item> {
    self.nodes.next()
  }

  /// Forwards the wrapped iterator's bounds so consumers such as `collect`
  /// can size their buffers up front.
  fn size_hint(&self) -> (usize, Option<usize>) {
    self.nodes.size_hint()
  }
}

/// Iterator over the index of every edge currently stored in the tree's graph.
/// Each edge points from a child node to its parent.
pub struct EdgeIter<'a, X, T, const N: usize>
where
  T: Trajectory<X, N>,
{
  edges: EdgeIndices<'a, T>,
  // `X` only appears in the `T: Trajectory<X, N>` bound, so it needs a marker.
  phantom_x: PhantomData<X>,
}

impl<'a, X, T, const N: usize> EdgeIter<'a, X, T, N>
where
  X: Scalar,
  T: Trajectory<X, N>,
{
  /// Wraps the edge-index iterator of `graph`.
  fn new(graph: &'a StableDiGraph<SVector<X, N>, T>) -> Self {
    Self {
      edges: graph.edge_indices(),
      phantom_x: PhantomData,
    }
  }
}

impl<'a, X, T, const N: usize> Iterator for EdgeIter<'a, X, T, N>
where
  T: Trajectory<X, N>,
{
  type Item = EdgeIndex;

  fn next(&mut self) -> Option<Self::Item> {
    self.edges.next()
  }

  /// Forwards the wrapped iterator's bounds so consumers such as `collect`
  /// can size their buffers up front.
  fn size_hint(&self) -> (usize, Option<usize>) {
    self.edges.size_hint()
  }
}

/// Iterator over the node indices on the path from a start node towards the
/// root of an [`RrtTree`], obtained by repeatedly following parent edges.
///
/// The start node is yielded first, then its parent, and so on. Iteration ends
/// after yielding a node that has no parent (normally the goal/root). This
/// assumes the parent links contain no cycle; otherwise it would never end.
pub struct OptimalPathIter<'a, X, T, const N: usize>
where
  X: Scalar,
  T: Trajectory<X, N>,
{
  graph: &'a RrtTree<X, T, N>,
  next_node: Option<NodeIndex>,
}

impl<'a, X, T, const N: usize> OptimalPathIter<'a, X, T, N>
where
  X: Scalar,
  T: Trajectory<X, N>,
{
  /// Starts a path walk at `node`.
  fn new(graph: &'a RrtTree<X, T, N>, node: NodeIndex) -> Self {
    Self {
      graph,
      next_node: Some(node),
    }
  }

  /// Converts this iterator into an [`OptimalPathWalker`] that resumes from
  /// the same position but does not hold a borrow of the tree.
  pub fn detach(self) -> OptimalPathWalker {
    OptimalPathWalker::new(self.next_node)
  }
}

impl<'a, X, T, const N: usize> Iterator for OptimalPathIter<'a, X, T, N>
where
  X: Scalar,
  T: Trajectory<X, N>,
{
  type Item = NodeIndex;

  fn next(&mut self) -> Option<Self::Item> {
    let node = self.next_node?;
    // Step to the parent; this becomes `None` once `node` has no parent edge.
    self.next_node = self.graph.parent(node);
    Some(node)
  }
}

// Once `next_node` is `None` nothing sets it again, so `next` keeps returning
// `None` after exhaustion.
impl<'a, X, T, const N: usize> std::iter::FusedIterator
  for OptimalPathIter<'a, X, T, N>
where
  X: Scalar,
  T: Trajectory<X, N>,
{
}

/// Detached walker over the node indices on the path from a start node towards
/// the root of an [`RrtTree`], following parent edges.
///
/// Unlike [`OptimalPathIter`] it stores no reference to the tree; the tree is
/// passed to each [`next`](OptimalPathWalker::next) call instead, so the tree
/// is only borrowed for the duration of a single step.
pub struct OptimalPathWalker {
  next_node: Option<NodeIndex>,
}

impl OptimalPathWalker {
  /// Creates a walker whose next yielded node is `node`; `None` produces an
  /// already-exhausted walker.
  fn new(node: Option<NodeIndex>) -> Self {
    Self { next_node: node }
  }

  /// Returns the current node of the walk in `g` and advances to its parent.
  ///
  /// Returns `None` once the walk has yielded a node without a parent.
  pub fn next<X, T, const N: usize>(
    &mut self,
    g: &RrtTree<X, T, N>,
  ) -> Option<NodeIndex>
  where
    X: Scalar,
    T: Trajectory<X, N>,
  {
    let node = self.next_node?;
    // Step to the parent; this becomes `None` once `node` has no parent edge.
    self.next_node = g.parent(node);
    Some(node)
  }
}

/// Tree of states rooted at the goal node.
///
/// Nodes hold states (`SVector<X, N>`) and edges hold trajectory data `T`.
/// Every edge is directed from a child to its parent: a node's parent is
/// reached through its outgoing edge and its children through its incoming
/// edges. Nodes may additionally be flagged as orphans; flagged nodes stay in
/// the graph until `clear_orphans` removes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(
  serialize = "X: Serialize, T: Serialize",
  deserialize = "X: DeserializeOwned, T: DeserializeOwned",
))]
pub struct RrtTree<X, T, const N: usize>
where
  X: Scalar,
  T: Trajectory<X, N>,
{
  /// Index of the root (goal) node, created in `new`.
  goal_idx: NodeIndex,
  /// Underlying graph; edges point child -> parent.
  graph: StableDiGraph<SVector<X, N>, T>,
  /// Indices of nodes currently flagged as orphans (still present in `graph`).
  orphans: HashSet<NodeIndex>,
  /// Type marker for `X`; not serialized (rebuilt via `Default` on
  /// deserialization).
  #[serde(skip)]
  phantom_x: PhantomData<X>,
}

impl<X, T, const N: usize> RrtTree<X, T, N>
where
  X: Scalar,
  T: Trajectory<X, N>,
{
  /// Creates a new tree containing only `goal`, which becomes the root node
  pub fn new(goal: SVector<X, N>) -> Self {
    let mut graph = StableDiGraph::new();
    let goal_idx = graph.add_node(goal);

    Self {
      goal_idx,
      graph,
      orphans: HashSet::new(),
      phantom_x: PhantomData,
    }
  }

  /// Returns the number of nodes in the graph (orphaned nodes are counted
  /// until `clear_orphans` removes them)
  pub fn node_count(&self) -> usize {
    self.graph.node_count()
  }

  /// Returns the index of the goal (root) node
  pub fn get_goal_idx(&self) -> NodeIndex {
    self.goal_idx
  }

  /// Returns a reference to the goal state
  pub fn get_goal(&self) -> &SVector<X, N> {
    self.get_node(self.goal_idx)
  }

  /// Returns an iterator over the indices of all nodes in the graph
  pub fn all_nodes(&self) -> NodeIter<X, N> {
    NodeIter::new(&self.graph)
  }

  /// Returns an iterator over the indices of all edges in the graph
  pub fn all_edges(&self) -> EdgeIter<X, T, N> {
    EdgeIter::new(&self.graph)
  }

  /// Returns the parent of `node` in the tree, or `None` if `node` has no
  /// parent edge (e.g. the goal/root)
  pub fn parent(&self, node: NodeIndex) -> Option<NodeIndex> {
    self.parent_edge(node).map(|edge| edge.target())
  }

  /// Returns the edge from `node` to its parent, if any.
  ///
  /// Edges point child -> parent, so this is the first outgoing edge;
  /// `update_edge` keeps at most one such edge per node.
  fn parent_edge(&self, node: NodeIndex) -> Option<EdgeReference<T>> {
    self.graph.edges_directed(node, Direction::Outgoing).next()
  }

  /// Returns true if there is an edge `node -> parent`, i.e. `parent` is the
  /// parent of `node`
  pub fn is_parent(&self, node: NodeIndex, parent: NodeIndex) -> bool {
    self.graph.find_edge(node, parent).is_some()
  }

  /// Returns an iterator over the children of `node` (the sources of its
  /// incoming edges). The iterator is empty for leaf nodes
  pub fn children(&self, node: NodeIndex) -> Neighbors<T, DefaultIx> {
    self.graph.neighbors_directed(node, Direction::Incoming)
  }

  /// Like `children`, but returns a detached walker that does not hold a
  /// borrow of the tree. The walker yields nothing for leaf nodes
  pub fn children_walker(&self, node: NodeIndex) -> WalkNeighbors<DefaultIx> {
    self
      .graph
      .neighbors_directed(node, Direction::Incoming)
      .detach()
  }

  /// Returns true if there is an edge `child -> node`, i.e. `node` is the
  /// parent of `child`
  pub fn is_child(&self, node: NodeIndex, child: NodeIndex) -> bool {
    self.is_parent(child, node)
  }

  /// Flags `node` and all of its descendants as orphans.
  ///
  /// The traversal does not descend below nodes that are already orphans.
  /// This relies on the invariant that every node already in the orphan set
  /// has had its descendants flagged as well. An explicit stack is used
  /// instead of recursion so that very deep subtrees cannot overflow the call
  /// stack.
  pub fn add_orphan(&mut self, node: NodeIndex) {
    self.orphans.insert(node);

    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
      for child in self.graph.neighbors_directed(current, Direction::Incoming)
      {
        // `insert` returns false for existing orphans, whose descendants are
        // already flagged, so they need not be visited again
        if self.orphans.insert(child) {
          pending.push(child);
        }
      }
    }
  }

  /// Removes `node` from the orphan set; the graph itself is not modified
  pub fn remove_orphan(&mut self, node: NodeIndex) {
    self.orphans.remove(&node);
  }

  /// Returns true if `node` is in the orphan set
  pub fn is_orphan(&self, node: NodeIndex) -> bool {
    self.orphans.contains(&node)
  }

  /// Returns an iterator over all orphan node indices (in unspecified order)
  pub fn orphans(&self) -> impl Iterator<Item = NodeIndex> + '_ {
    self.orphans.iter().copied()
  }

  /// Removes every orphan node (and its incident edges) from the graph and
  /// empties the orphan set
  pub fn clear_orphans(&mut self) {
    // `drain` empties the set while yielding its contents, so no temporary
    // collection is needed
    for orphan_idx in self.orphans.drain() {
      self.graph.remove_node(orphan_idx);
    }
  }

  /// Adds `node` as a child of `parent`, connected by `trajectory`.
  ///
  /// Returns the index of the new node and of the edge to its parent
  pub fn add_node(
    &mut self,
    node: SVector<X, N>,
    parent: NodeIndex,
    trajectory: T,
  ) -> (NodeIndex, EdgeIndex) {
    let node_idx = self.graph.add_node(node);
    let edge_idx = self.update_edge(node_idx, parent, trajectory);
    (node_idx, edge_idx)
  }

  /// Makes `new_parent` the parent of `node`, replacing any existing parent
  /// edge, and returns the index of the new edge
  pub fn update_edge(
    &mut self,
    node: NodeIndex,
    new_parent: NodeIndex,
    new_trajectory: T,
  ) -> EdgeIndex {
    self.remove_any_parents(node);
    // `node` has no outgoing edges left, so a plain insert cannot create a
    // duplicate parent edge
    self.graph.add_edge(node, new_parent, new_trajectory)
  }

  /// Removes every outgoing (child -> parent) edge of `node`.
  ///
  /// Returns true if any edge was removed
  fn remove_any_parents(&mut self, node: NodeIndex) -> bool {
    // Collect first: edges cannot be removed while the graph is borrowed by
    // the edge iterator
    let parent_edges: Vec<EdgeIndex> = self
      .graph
      .edges_directed(node, Direction::Outgoing)
      .map(|edge_ref| edge_ref.id())
      .collect();

    for &edge_idx in &parent_edges {
      self.graph.remove_edge(edge_idx);
    }
    !parent_edges.is_empty()
  }

  /// Returns an iterator over the node indices on the path from `node`
  /// towards the root, starting with `node` itself.
  ///
  /// Returns `None` if `node` is flagged as an orphan
  pub fn get_optimal_path(
    &self,
    node: NodeIndex,
  ) -> Option<OptimalPathIter<X, T, N>> {
    if self.is_orphan(node) {
      None
    } else {
      Some(OptimalPathIter::new(self, node))
    }
  }

  /// Returns a reference to the state stored at the specified node
  ///
  /// panics if `idx` is invalid
  pub fn get_node(&self, idx: NodeIndex) -> &SVector<X, N> {
    self.graph.index(idx)
  }

  /// Returns a mutable reference to the state stored at the specified node
  ///
  /// panics if `idx` is invalid
  pub fn get_node_mut(&mut self, idx: NodeIndex) -> &mut SVector<X, N> {
    self.graph.index_mut(idx)
  }

  /// Returns a reference to the trajectory data stored at the specified edge
  ///
  /// panics if `idx` is invalid
  pub fn get_edge(&self, idx: EdgeIndex) -> &T {
    self.graph.index(idx)
  }

  /// Returns the `(source, target)` endpoints of an edge, i.e.
  /// `(child, parent)`
  ///
  /// panics if `idx` is invalid
  pub fn get_endpoints(&self, idx: EdgeIndex) -> (NodeIndex, NodeIndex) {
    self
      .graph
      .edge_endpoints(idx)
      .expect("edge index is not present in the tree")
  }

  /// Returns the full trajectory for the specified edge: its source (child)
  /// state as the start, its target (parent) state as the end, and the
  /// trajectory data stored on the edge
  ///
  /// panics if `idx` is invalid
  pub fn get_trajectory(&self, idx: EdgeIndex) -> FullTrajRefOwned<X, T, N> {
    let (start_idx, end_idx) = self.get_endpoints(idx);
    let start = self.get_node(start_idx);
    let end = self.get_node(end_idx);
    let traj_data = self.get_edge(idx);
    FullTrajRefOwned::new(start, end, traj_data)
  }
}

#[cfg(test)]
mod tests {

  use crate::trajectories::EuclideanTrajectory;

  use super::*;

  /// The goal is the root (no parent) and both inserted nodes hang off it.
  #[test]
  fn test_rrt_tree_parent() {
    let mut g = RrtTree::new([1.5, 1.5].into());
    let goal = g.get_goal_idx();

    let (n1, _) =
      g.add_node([2.0, 2.0].into(), goal, EuclideanTrajectory::new());
    let (n2, _) =
      g.add_node([-2.0, -2.0].into(), goal, EuclideanTrajectory::new());

    assert_eq!(g.parent(goal), None);
    assert_eq!(g.parent(n1), Some(goal));
    assert_eq!(g.parent(n2), Some(goal));
  }

  /// `children` lists the goal's children (latest attached first) and
  /// nothing for the leaves.
  #[test]
  fn test_rrt_tree_children() {
    let mut g = RrtTree::new([1.5, 1.5].into());
    let goal = g.get_goal_idx();

    let (n1, _) =
      g.add_node([2.0, 2.0].into(), goal, EuclideanTrajectory::new());
    let (n2, _) =
      g.add_node([-2.0, -2.0].into(), goal, EuclideanTrajectory::new());

    let goal_children: Vec<NodeIndex> = g.children(goal).collect();
    assert_eq!(goal_children, vec![n2, n1]);

    assert_eq!(g.children(n1).next(), None);
    assert_eq!(g.children(n2).next(), None);
  }

  /// The detached walker visits the same children in the same order.
  #[test]
  fn test_rrt_tree_children_walker() {
    let mut g = RrtTree::new([1.5, 1.5].into());
    let goal = g.get_goal_idx();

    let (n1, _) =
      g.add_node([2.0, 2.0].into(), goal, EuclideanTrajectory::new());
    let (n2, _) =
      g.add_node([-2.0, -2.0].into(), goal, EuclideanTrajectory::new());

    // Drains a children walker for `node` into a Vec
    let walk = |node: NodeIndex| {
      let mut walker = g.children_walker(node);
      let mut found = Vec::new();
      while let Some(child) = walker.next_node(&g.graph) {
        found.push(child);
      }
      found
    };

    assert_eq!(walk(goal), vec![n2, n1]);
    assert!(walk(n1).is_empty());
    assert!(walk(n2).is_empty());
  }
}
