/* Copyright (C) 2020 Dylan Staatz - All Rights Reserved. */

use kdtree::KdTree;
use nalgebra::storage::Storage;
use nalgebra::{Const, SVector, VectorSlice};
use petgraph::stable_graph::NodeIndex;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use crate::cspace::CSpace;
use crate::error::{InvalidParamError, Result};
use crate::obstacles::{Obstacle, ObstacleSpace};
use crate::path_planner::{MoveGoal, PathPlanner, RobotSpecs};
use crate::scalar::Scalar;
use crate::trajectories::{FullTrajOwned, FullTrajRefOwned, FullTrajectory};

use super::RrtTree;

/// RRT generic parameters
///
/// Tuning parameters for [`Rrt`]; they are validated by `Rrt::new`.
///
/// The struct is serialized with serde, and binary formats such as `bincode`
/// encode fields positionally, so new fields should be appended rather than
/// inserted or reordered.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(bound(
  serialize = "X: Serialize",
  deserialize = "X: DeserializeOwned"
))]
pub struct RrtParams<X> {
  /// Cost below which the robot counts as having reached its current move
  /// goal, and the minimum trajectory cost accepted when advancing the move
  /// goal along the optimal path. Must satisfy `0 <= min_cost < delta`.
  pub min_cost: X,
  /// Fraction of the initial cost to the current move goal; must be strictly
  /// between 0 and 1. Getting closer by more than this fraction counts as
  /// reaching the move goal, drifting away by more than it abandons the goal.
  pub portion: X,
  /// Maximum cost of a tree edge. Samples farther than this from their nearest
  /// node are saturated relative to that node (`CSpace::saturate`), and it also
  /// bounds the nearest-node searches used to pick a new move goal.
  pub delta: X,
}

/// Rrt Serializable State information
///
/// Everything needed to describe the planner's progress except the
/// configuration space and the kd-tree index. Fields are serialized in
/// declaration order, so keep that order stable.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound(
  serialize = "X: Serialize, CS::Traj: Serialize, OS: Serialize",
  deserialize = "X: DeserializeOwned, CS::Traj: DeserializeOwned, OS: DeserializeOwned",
))]
pub struct RrtState<X, CS, OS, const N: usize>
where
  X: Scalar,
  CS: CSpace<X, N>,
  OS: ObstacleSpace<X, CS, N>,
{
  /// All the nodes with their associated costs and edges that connect them
  pub tree: RrtTree<X, CS::Traj, N>,

  /// The path currently being travelled, as
  /// `(start_pose, trajectory, move_goal_idx)`: the pose the path was planned
  /// from, the trajectory from that pose, and the tree node being driven to.
  pub current_path: Option<(SVector<X, N>, CS::Traj, NodeIndex)>,

  /// The last known pose of the system
  pub pose: SVector<X, N>,

  /// Tracks the obstacles in the space
  pub obs_space: OS,

  /// Robot specifications
  pub robot_specs: RobotSpecs<X>,

  /// Algorithm parameters
  pub params: RrtParams<X>,
}

impl<X, CS, OS, const N: usize> Clone for RrtState<X, CS, OS, N>
where
  X: Scalar,
  CS: CSpace<X, N>,
  OS: ObstacleSpace<X, CS, N>,
{
  fn clone(&self) -> Self {
    Self {
      tree: self.tree.clone(),
      current_path: self.current_path.clone(),
      pose: self.pose.clone(),
      obs_space: self.obs_space.clone(),
      robot_specs: self.robot_specs.clone(),
      params: self.params.clone(),
    }
  }
}

/// Rrt implementation
///
/// The tree is created from the goal (`RrtTree::new(goal)` in `new`), and
/// sampled configurations are attached to their nearest existing node.
pub struct Rrt<X, CS, OS, const N: usize>
where
  X: Scalar,
  CS: CSpace<X, N>,
  OS: ObstacleSpace<X, CS, N>,
{
  /// Serializable planner state: tree, current path, pose, obstacles, robot
  /// specifications and parameters
  state: RrtState<X, CS, OS, N>,

  /// Spatial index over the tree's node positions for fast nearest-neighbour
  /// searching; maps coordinates to the node's `NodeIndex`. It must stay in
  /// sync with `state.tree` (`count` asserts that their sizes match).
  kdtree: KdTree<X, NodeIndex, [X; N]>,

  /// Configuration space: sampling, costs, trajectories and free-space checks
  cspace: CS,
}

impl<X, CS, OS, const N: usize> PathPlanner<X, CS, OS, N> for Rrt<X, CS, OS, N>
where
  X: Scalar,
  CS: CSpace<X, N> + Send + Sync,
  CS::Traj: Send + Sync,
  OS: ObstacleSpace<X, CS, N> + Send + Sync,
  OS::Obs: Send + Sync,
{
  type Params = RrtParams<X>;
  type State = RrtState<X, CS, OS, N>;

  fn new(
    init: SVector<X, N>,
    goal: SVector<X, N>,
    robot_specs: RobotSpecs<X>,
    cspace: CS,
    obs_space: OS,
    params: Self::Params,
  ) -> Result<Self> {
    let tree = RrtTree::new(goal.clone());
    let current_path = None;
    let pose = init.clone();

    let mut kdtree: KdTree<X, NodeIndex, [X; N]> = KdTree::new(N.into());
    kdtree
      .add(goal.into(), tree.get_goal_idx())
      .expect("kdtree error");

    // Validate robot_radius are greater than 0
    if !(X::zero() < robot_specs.robot_radius) {
      Err(InvalidParamError {
        parameter_name: "robot_specs.robot_radius",
        parameter_value: format!("{:?}", robot_specs.robot_radius),
      })?;
    }

    // Validate sensor_radius are greater than 0
    if !(X::zero() < robot_specs.sensor_radius) {
      Err(InvalidParamError {
        parameter_name: "robot_specs.sensor_radius",
        parameter_value: format!("{:?}", robot_specs.sensor_radius),
      })?;
    }

    // Validate min_cost is in [0, delta)
    if !(X::zero() <= params.min_cost && params.min_cost < params.delta) {
      Err(InvalidParamError {
        parameter_name: "params.min_cost",
        parameter_value: format!("{:?}", params.min_cost),
      })?;
    }

    // Validate portion is between 0 and 1
    if !(X::zero() < params.portion && params.portion < X::one()) {
      Err(InvalidParamError {
        parameter_name: "params.portion",
        parameter_value: format!("{:?}", params.portion),
      })?;
    }

    let state = RrtState {
      tree,
      current_path,
      pose,
      obs_space,
      robot_specs,
      params,
    };

    let new = Self {
      state,
      kdtree,
      cspace,
    };

    // Check that the init and goal locations are in free space
    // Check that the init and goal locations are in free space
    if !new.is_free(&init) {
      Err(InvalidParamError {
        parameter_name: "init",
        parameter_value: format!("{:?}", init),
      })?;
    }
    if !new.is_free(&goal) {
      Err(InvalidParamError {
        parameter_name: "goal",
        parameter_value: format!("{:?}", goal),
      })?;
    }

    Ok(new)
  }

  /// Adds exactly one node to the tree, retrying failed samples as often as
  /// needed, and returns the updated state.
  fn create_node(&mut self) -> &Self::State {
    while self.try_create_node().is_none() {}
    self.get_state()
  }

  /// Makes a single attempt to add a node. Returns the updated state on
  /// success, or `None` if this sample was rejected.
  fn sample_node(&mut self) -> Option<&Self::State> {
    let added = self.try_create_node().is_some();
    if added {
      Some(self.get_state())
    } else {
      None
    }
  }

  fn check_sensors(&mut self) {
    let added = self
      .state
      .obs_space
      .check_sensors(&self.state.pose, self.state.robot_specs.sensor_radius);

    if !added.is_empty() {
      // Copy out the added obstacles into a temporary obstacle space
      let obs = self
        .state
        .obs_space
        .get_obstacles(&added[..])
        .into_iter()
        .cloned()
        .collect();
      let added_obs_space = OS::new(obs);

      // Add to environment possibly creating orphans
      self.add_obstacle_to_environment(added_obs_space);

      // Cleanup
      // if the current target has been orphaned, remove it
      if let Some((_, _, move_goal_idx)) = self.state.current_path {
        if self.state.tree.is_orphan(move_goal_idx) {
          self.state.current_path = None;
        }
      }

      // TODO: BUG:
      // removing from tree without removing from kdtree will cause problems
      self.state.tree.clear_orphans();
    }
  }

  /// Returns the obstacle space the planner works with.
  fn get_obs_space(&self) -> &OS {
    &self.state.obs_space
  }

  /// Mutable access to the obstacle space.
  ///
  /// Changes made through this reference are not applied to the tree here.
  /// Use `check_nodes(true, _)` to re-validate the tree's nodes against it.
  fn get_obs_space_mut(&mut self) -> &mut OS {
    &mut self.state.obs_space
  }

  /// Returns the configuration space.
  fn get_cspace(&self) -> &CS {
    &self.cspace
  }

  /// Mutable access to the configuration space.
  ///
  /// Use `check_nodes(_, true)` to re-validate the tree's nodes after changes.
  fn get_cspace_mut(&mut self) -> &mut CS {
    &mut self.cspace
  }

  fn check_nodes(&mut self, use_obs_space: bool, use_cspace: bool) {
    if !(use_obs_space || use_cspace) {
      // Nothing to do
      return;
    }

    let all_nodes = self.state.tree.all_nodes().collect::<Vec<_>>();
    let tree = &self.state.tree;
    let cspace = &self.cspace;
    let obs_space = &self.state.obs_space;

    let iter = all_nodes.into_par_iter().filter_map(|u_idx| {
      let u = tree.get_node(u_idx);

      // Check cspace and/or obs_space if required
      if (use_cspace && !cspace.is_free(u))
        || (use_obs_space && !obs_space.is_free(u))
      {
        Some(u_idx)
      } else {
        None
      }
    });
    let vec = iter.collect::<Vec<_>>();

    for u_idx in vec {
      self.state.tree.add_orphan(u_idx);
    }
    // TODO: BUG:
    // removing from tree without removing from kdtree will cause problems
    self.state.tree.clear_orphans();
  }

  /// Records a new pose for the system and returns the move goal to head for.
  ///
  /// Sensors are checked from the new pose first, which may orphan nodes and
  /// drop the current path. Then, unless `nearest` is `true`, the current move
  /// goal is re-evaluated using costs from the configuration space:
  /// - if the trajectory from `pose` to the move goal is infeasible or blocked,
  ///   the move goal is discarded;
  /// - if the pose got closer than the initial cost (measured from the path's
  ///   start pose) by more than `portion` of it, or is within `min_cost`, the
  ///   move goal counts as reached. At the global goal this returns
  ///   `MoveGoal::Finished`. Otherwise the next viable node along the optimal
  ///   path becomes the new move goal (`MoveGoal::New`). If no such node is
  ///   found, the move goal is discarded;
  /// - if the pose drifted farther than the initial cost by more than
  ///   `portion` of it, the move goal is discarded;
  /// - otherwise the move goal is kept (`MoveGoal::Old`).
  ///
  /// When the move goal is discarded, or when `nearest` is `true`, a new path
  /// is searched starting from the tree nodes nearest to `pose`. Returns `None`
  /// if that search finds no move goal.
  fn update_pose(
    &mut self,
    pose: SVector<X, N>,
    nearest: bool,
  ) -> Option<MoveGoal<X, N>> {
    self.state.pose = pose;
    log::info!("Updating pose to {:?}", <[X; N]>::from(pose));

    // Check sensors from this new pose
    self.check_sensors();

    if !nearest {
      // Check that the next move goal is still valid for this new pose,
      // if so, return it
      // otherwise, determine a new move goal
      if let Some((start_pose, _, move_goal_idx)) = self.state.current_path {
        let move_goal = self.state.tree.get_node(move_goal_idx);
        let pose_ref = pose.index((.., 0));
        let mg_ref = move_goal.index((.., 0));

        if let Some(new_trajectory) = self.cspace.trajectory(pose_ref, mg_ref) {
          if self.trajectory_free(&new_trajectory) {
            // Trajectory is valid
            // Cost from where the path started vs. cost from where we are now
            let init_dist_err = self.cspace.cost(&start_pose, move_goal);
            let rel_dist_err = self.cspace.cost(&pose, move_goal);

            if init_dist_err > rel_dist_err {
              // We are closer to the move goal than initially

              // Progress larger than `portion` of the initial cost, or being
              // within `min_cost`, counts as reaching the move goal
              if init_dist_err - rel_dist_err
                > init_dist_err * self.state.params.portion
                || rel_dist_err < self.state.params.min_cost
              {
                // We have reached the goal we were aiming for

                // Check if we have reached the global goal node
                if self.state.tree.get_goal_idx() == move_goal_idx {
                  log::info!("Reached the finish!");
                  return Some(MoveGoal::Finished);
                }

                // Find the first viable parent in the optimal path that is more
                // than min_cost away
                log::info!("Reached move goal, looking for next along path");
                let res = self.find_move_goal_along_path(&pose, move_goal_idx);

                // Update if found
                if let Some((new_trajectory, new_move_goal_idx)) = res {
                  // The new path is measured from the current pose
                  self.state.current_path =
                    Some((pose, new_trajectory, new_move_goal_idx));
                  log::info!("New move goal found");
                  return Some(MoveGoal::New(*self.get_current_path()?.end()));
                } else {
                  // Falls through to the nearest-node search below
                  log::info!("No valid move goal along path");
                }
              } else {
                // Keep the same move goal
                log::info!("Keeping same move goal");
                return Some(MoveGoal::Old(*self.get_current_path()?.end()));
              }
            } else {
              // We are farther from the move goal than initially (or equal)

              if rel_dist_err - init_dist_err
                > init_dist_err * self.state.params.portion
              {
                // We are out of range of the goal we were aiming for
                // Falling through to look for a new path
                log::info!("Out of range of move goal");
              } else {
                // Keep the same move goal
                log::info!("Keeping same move goal");
                return Some(MoveGoal::Old(*self.get_current_path()?.end()));
              }
            }
          } else {
            // Trajectory is blocked by an obstacle
            log::info!("Current move goal blocked by obstacle");
          }
        } else {
          // The trajectory is invalid
          log::info!("Trajectory to current move goal infesible");
        }
      }
    }

    // The current move goal is invalid, find a new one
    log::info!("Move goal invalid, looking for a new one");
    self.state.current_path = self.find_new_path(pose);
    // `?` yields None when no new path was found
    Some(MoveGoal::New(*self.get_current_path()?.end()))
  }

  /// Not implemented for `Rrt` yet: always panics via `unimplemented!()`.
  // NOTE(review): could presumably be built from `state.tree.all_edges()`,
  // `get_endpoints` and `get_trajectory` (as used in
  // `add_obstacle_to_environment`) — TODO confirm `FullTrajRefOwned`'s constructor.
  fn get_tree(&self) -> Vec<FullTrajRefOwned<X, CS::Traj, N>> {
    unimplemented!()
  }

  /// Returns an owned copy of the path being travelled. It runs from the pose
  /// it was planned at to the tree node acting as move goal. Returns `None`
  /// when there is no current path.
  fn get_current_path(&self) -> Option<FullTrajOwned<X, CS::Traj, N>> {
    let (start, trajectory, goal_idx) = self.state.current_path.clone()?;
    let end = self.state.tree.get_node(goal_idx).clone();
    Some(FullTrajOwned::new(start, end, trajectory))
  }

  /// Returns the node positions along the tree's optimal path, starting at the
  /// current move goal. Returns `None` if there is no current path or the tree
  /// reports no optimal path from the move goal.
  fn get_path_to_goal(&self) -> Option<Vec<SVector<X, N>>> {
    let tree = &self.state.tree;
    let (_, _, move_goal_idx) = self.state.current_path.as_ref()?;
    let nodes = tree.get_optimal_path(*move_goal_idx)?;
    Some(nodes.map(|idx| tree.get_node(idx).clone()).collect())
  }

  fn get_last_pose(&self) -> &SVector<X, N> {
    &self.state.pose
  }

  fn get_state(&self) -> &Self::State {
    &self.state
  }

  fn get_goal(&self) -> &SVector<X, N> {
    self.state.tree.get_goal()
  }

  /// Returns the tree's node count plus one.
  ///
  /// # Panics
  ///
  /// Panics if the kd-tree index and the tree disagree on the number of nodes.
  // NOTE(review): the meaning of the `+ 1` is not visible here (the goal is
  // already one of the tree's nodes) — confirm against the trait's contract.
  fn count(&self) -> usize {
    let nodes = self.state.tree.node_count();
    assert_eq!(self.kdtree.size(), nodes);
    nodes + 1
  }
}

impl<X, CS, OS, const N: usize> Rrt<X, CS, OS, N>
where
  X: Scalar,
  CS: CSpace<X, N> + Send + Sync,
  CS::Traj: Send + Sync,
  OS: ObstacleSpace<X, CS, N> + Send + Sync,
  OS::Obs: Send + Sync,
{
  /// Samples one configuration and tries to attach it to the tree.
  ///
  /// If the sample is more than `params.delta` from its nearest tree node, it
  /// is first saturated relative to that node. Returns `None` when the
  /// (possibly saturated) sample is not free, or when it cannot be connected
  /// to that nearest node.
  fn try_create_node(&mut self) -> Option<()> {
    let mut sample = self.cspace.sample();
    let (dist, nearest_idx) = self.nearest(&sample);

    let delta = self.state.params.delta;
    if dist > delta {
      let anchor = self.state.tree.get_node(nearest_idx);
      self.cspace.saturate(&mut sample, anchor, delta);
    }

    if !self.is_free(&sample) {
      return None;
    }
    self.extend(sample, nearest_idx).map(|_| ())
  }

  /// Returns the tree node nearest to `v`, using the configuration space's
  /// cost as the metric, together with that cost.
  fn nearest(&self, v: &SVector<X, N>) -> (X, NodeIndex) {
    let cspace = &self.cspace;
    let metric = |lhs: &[X], rhs: &[X]| {
      let lhs = VectorSlice::<X, Const<N>>::from_slice(lhs);
      let rhs = VectorSlice::<X, Const<N>>::from_slice(rhs);
      cspace.cost(&lhs, &rhs)
    };

    let found = self.kdtree.nearest(v.into(), 1, &metric).unwrap();
    let &(cost, &idx) = found.first().unwrap();
    (cost, idx)
  }

  /// Try to extend the tree to include the given node with parent u
  fn extend(
    &mut self,
    v: SVector<X, N>,
    u_idx: NodeIndex,
  ) -> Option<NodeIndex> {
    let trajectory = self.check_parent(&v, u_idx)?;

    let (v_idx, _) = self.state.tree.add_node(v, u_idx, trajectory);
    self
      .kdtree
      .add(self.state.tree.get_node(v_idx).clone().into(), v_idx)
      .expect("kdtree error");

    Some(v_idx)
  }

  /// Returns the trajectory from `v` to the candidate parent `u_idx` if it is
  /// feasible. Feasible means the configuration space can connect the two
  /// points, the cost does not exceed `params.delta`, and no known obstacle
  /// blocks the trajectory.
  fn check_parent(
    &self,
    v: &SVector<X, N>,
    u_idx: NodeIndex,
  ) -> Option<CS::Traj> {
    let parent = self.state.tree.get_node(u_idx);
    let from = v.index((.., 0));
    let to = parent.index((.., 0));

    let candidate = self.cspace.trajectory(from, to)?;

    // Cost is checked first; the obstacle check only runs for short edges
    let exceeds_delta = candidate.cost() > self.state.params.delta;
    if exceeds_delta || !self.trajectory_free(&candidate) {
      None
    } else {
      Some(candidate.to_trajectory())
    }
  }

  /// Orphans the `v` endpoint (the first value from `get_endpoints`) of every
  /// tree edge whose trajectory is blocked by `obstacle`.
  ///
  /// Nodes are only marked with `add_orphan`; clearing them is left to the
  /// caller.
  fn add_obstacle_to_environment<O>(&mut self, obstacle: O)
  where
    O: Obstacle<X, CS, N> + Send + Sync,
  {
    let tree = &self.state.tree;
    let edges: Vec<_> = tree.all_edges().collect();

    // Test every edge against the obstacle in parallel
    let blocked: Vec<_> = edges
      .into_par_iter()
      .filter_map(|edge_idx| {
        let trajectory = tree.get_trajectory(edge_idx);
        if obstacle.trajectory_free(&trajectory) {
          None
        } else {
          let (v_idx, _u_idx) = tree.get_endpoints(edge_idx);
          Some(v_idx)
        }
      })
      .collect();

    for v_idx in blocked {
      self.state.tree.add_orphan(v_idx);
    }
  }

  /// Returns `true` when `v` is free in the configuration space and does not
  /// collide with any known obstacle.
  fn is_free(&self, v: &SVector<X, N>) -> bool {
    if !self.cspace.is_free(v) {
      return false;
    }
    self.state.obs_space.is_free(v)
  }

  /// Returns `true` when trajectory `t` does not intersect any known
  /// obstacle. Only the obstacle space is consulted.
  fn trajectory_free<TF, S1, S2>(&self, t: &TF) -> bool
  where
    TF: FullTrajectory<X, CS::Traj, S1, S2, N>,
    S1: Storage<X, Const<N>>,
    S2: Storage<X, Const<N>>,
  {
    let obstacles = &self.state.obs_space;
    obstacles.trajectory_free(t)
  }

  /// Finds the next move goal along the optimal path from `move_goal_idx`.
  ///
  /// Walks the nodes after `move_goal_idx` on its optimal path. Returns the
  /// first one reachable from `pose` by a free trajectory costing more than
  /// `params.min_cost`, together with that trajectory. Reachable nodes that
  /// are too close are skipped. If a node is unreachable (no trajectory, or a
  /// blocked one) and lies more than `params.delta` from `pose`, the search
  /// stops and returns `None`.
  ///
  /// # Panics
  ///
  /// Panics if `move_goal_idx` has no optimal path (e.g. it is an orphan).
  fn find_move_goal_along_path(
    &self,
    pose: &SVector<X, N>,
    move_goal_idx: NodeIndex,
  ) -> Option<(CS::Traj, NodeIndex)> {
    let params = &self.state.params;

    // Invariant: a node that is a move goal always has an optimal path.
    // The first entry is `move_goal_idx` itself, so it is skipped.
    let remaining = self
      .state
      .tree
      .get_optimal_path(move_goal_idx)
      .unwrap()
      .skip(1);

    for candidate_idx in remaining {
      let candidate = self.state.tree.get_node(candidate_idx);

      let reachable = self
        .cspace
        .trajectory(pose.clone(), candidate.clone())
        .filter(|trajectory| self.trajectory_free(trajectory));

      match reachable {
        Some(trajectory) => {
          // Reachable: accept only if it is more than min_cost away
          if params.min_cost < trajectory.cost() {
            return Some((trajectory.to_trajectory(), candidate_idx));
          }
        }
        None => {
          // Unreachable: give up once the path leaves the delta neighbourhood
          if params.delta < self.cspace.cost(pose, candidate) {
            log::info!("Cutting search along path short");
            return None;
          }
        }
      }
    }
    None
  }

  /// Searches for a new move goal when the current one is unusable.
  ///
  /// Tree nodes are visited in increasing cost from `pose`, via the kd-tree.
  /// The first one accepted by `check_parent` becomes the move goal. The
  /// search stops once candidates are more than `params.delta` away.
  ///
  /// Returns `Some((pose, trajectory, node_idx))`, ready to store as the
  /// current path, or `None` if no node qualifies.
  fn find_new_path(
    &self,
    pose: SVector<X, N>,
  ) -> Option<(SVector<X, N>, CS::Traj, NodeIndex)> {
    let cspace = &self.cspace;
    let delta = self.state.params.delta;

    // Distance function handed to the kd-tree, expressed as cspace cost
    let metric = |lhs: &[X], rhs: &[X]| {
      let lhs = VectorSlice::<X, Const<N>>::from_slice(lhs);
      let rhs = VectorSlice::<X, Const<N>>::from_slice(rhs);
      cspace.cost(&lhs, &rhs)
    };
    let candidates = self
      .kdtree
      .iter_nearest(pose.as_slice(), &metric)
      .unwrap();

    log::info!("Searching nearest nodes");
    for (dist, &parent_idx) in candidates {
      if dist > delta {
        log::info!("Cutting search short");
        break;
      }

      if let Some(trajectory) = self.check_parent(&pose, parent_idx) {
        log::info!("Found new move goal");
        return Some((pose, trajectory, parent_idx));
      }
    }
    log::info!("End of search, no new move goal found");
    None
  }
}

#[cfg(test)]
mod tests {

  use parry3d::math::Isometry;
  use parry3d::shape::{Ball, Cuboid, SharedShape};
  use rand::SeedableRng;

  use crate::cspace::EuclideanSpace;
  use crate::obstacles::obstacles_3d_f32::{Obstacle3df32, ObstacleSpace3df32};
  use crate::rng::RNG;
  use crate::util::bounds::Bounds;

  use super::*;

  const SEED: u64 = 0xe580e2e93fd6b040;

  /// Upper bound on planning iterations, so a regression makes the test fail
  /// instead of hanging forever.
  const MAX_ITERATIONS: usize = 100_000;

  type TestRrt = Rrt<f32, EuclideanSpace<f32, 3>, ObstacleSpace3df32, 3>;

  /// Builds the planner shared by the tests. The space is a 10 x 1 x 10 box
  /// with a ball and a cube obstacle. Returns the initial pose together with
  /// the planner.
  fn build_rrt() -> (SVector<f32, 3>, TestRrt) {
    let init: SVector<f32, 3> = [5.0, 0.5, 5.0].into();
    let goal: SVector<f32, 3> = [-5.0, 0.5, -5.0].into();

    let robot_specs = RobotSpecs {
      robot_radius: 0.1,
      sensor_radius: 2.0,
    };

    let bounds = Bounds::new([-5.0, 0.0, -5.0].into(), [5.0, 1.0, 5.0].into());
    let cspace = EuclideanSpace::new(bounds, RNG::seed_from_u64(SEED)).unwrap();

    let ball = Obstacle3df32::with_offset(
      SharedShape::new(Ball::new(0.5)),
      Isometry::translation(0.0, 0.5, 0.0),
    );

    let cube = Obstacle3df32::with_offset(
      SharedShape::new(Cuboid::new([0.5, 0.5, 0.5].into())),
      Isometry::translation(2.5, 0.5, 2.5),
    );

    let obs_space = ObstacleSpace3df32::from(vec![ball, cube]);

    let params = RrtParams {
      min_cost: 0.0,
      portion: 0.1,
      delta: 1.0,
    };

    let rrt =
      Rrt::new(init, goal, robot_specs, cspace, obs_space, params).unwrap();
    (init, rrt)
  }

  /// Grows the tree, re-reporting `init` as the pose after each new node,
  /// until a path to the goal exists. Panics after `MAX_ITERATIONS`.
  fn plan_until_path(
    rrt: &mut TestRrt,
    init: SVector<f32, 3>,
  ) -> Vec<SVector<f32, 3>> {
    for _ in 0..MAX_ITERATIONS {
      rrt.create_node();
      rrt.update_pose(init, false);
      if let Some(path) = rrt.get_path_to_goal() {
        return path;
      }
    }
    panic!("no path to goal found after {} iterations", MAX_ITERATIONS);
  }

  #[test]
  fn test_rrt() {
    let (init, mut rrt) = build_rrt();
    let path = plan_until_path(&mut rrt, init);
    assert!(!path.is_empty());

    let cost: f32 = path
      .windows(2)
      .map(|pair| pair[0].metric_distance(&pair[1]))
      .sum();
    println!("{:?}", rrt.count());
    println!("{:?}", path);
    println!("{:?}", cost);
  }

  #[test]
  fn test_serialize_rrt_state() {
    let (init, mut rrt) = build_rrt();
    plan_until_path(&mut rrt, init);

    let state = rrt.get_state();
    let bytes = bincode::serialize(state).unwrap();
    let restored: RrtState<f32, EuclideanSpace<f32, 3>, ObstacleSpace3df32, 3> =
      bincode::deserialize(&bytes).unwrap();

    // Spot-check that plain-data fields survive the round trip
    assert_eq!(restored.pose, state.pose);
    assert_eq!(restored.params, state.params);
  }
}
