/*
MIT License

Copyright (c) 2021 University of Bristol Flight Laboratory

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Author: Mickey Li <mickeyhli@outlook.com>
*/

//! This module contains the core structures related to motion planning.
use std::clone::Clone;
use std::fmt;
use ndarray::{Array, Array1, Array2, array};
use rand::{thread_rng, seq::IteratorRandom, seq::SliceRandom};

use pyo3::prelude::*;

pub mod ss;
pub mod objects;
pub mod utils;
pub mod solvers;
pub mod trajectory;

/// Solver method used by the original motion planner.
///
/// Parsed case-insensitively from `"bruteforce"` / `"heuristic"` and displayed as
/// `BruteForce` / `Heuristic`, so the `Display` output parses back to the same variant.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum MPMethod {
    BRUTEFORCE,
    HEURISTIC
}
impl MPMethod {
    /// Parses a solver method name, ignoring letter case.
    ///
    /// Accepted values are `"bruteforce"` and `"heuristic"` (any case).
    /// Prefer `val.parse::<MPMethod>()` when the input may be invalid.
    ///
    /// # Panics
    /// Panics if `val` does not name a known method; the message includes the rejected value.
    pub fn from_string(val: String) -> MPMethod {
        val.parse::<MPMethod>().unwrap_or_else(|e| panic!("{}", e))
    }
}
/// Fallible, case-insensitive parsing of a method name.
///
/// The error is a human-readable message naming the rejected input.
impl std::str::FromStr for MPMethod {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "bruteforce" => Ok(MPMethod::BRUTEFORCE),
            "heuristic" => Ok(MPMethod::HEURISTIC),
            _ => Err(format!("Method not recognised: {:?}", s)),
        }
    }
}
/// Human-readable method name (`BruteForce` / `Heuristic`).
impl fmt::Display for MPMethod {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            Self::BRUTEFORCE => "BruteForce",
            Self::HEURISTIC => "Heuristic",
        };
        f.write_str(name)
    }
}

/// Internal configuration for the motion planner.
pub struct MPParams {
    /// Time-step setting. NOTE(review): units and exact use are not visible in this file — confirm.
    pub dt: usize,
    /// Tolerance value. NOTE(review): how it is applied is not visible in this file — confirm.
    pub epsilon: f64,
    /// Which solver strategy to run.
    pub method: MPMethod,
}

/// The planning problem: one [`MPComponent`] per axis plus the objects to avoid.
#[derive(FromPyObject, Debug, Clone)]
pub struct MPProblem {
    /// Problem time; [`MPProblem::new`] initialises it to `0.0`.
    pub time: f64,
    /// x-axis component
    pub x: MPComponent,
    /// y-axis component
    pub y: MPComponent,
    /// z-axis component
    pub z: MPComponent,
    /// Objects considered during obstacle avoidance
    pub objects: objects::ObjectList,
}
impl MPProblem {
    /// Builds a problem from three axis components and an object list, starting at `time = 0.0`.
    pub fn new(x: MPComponent, y: MPComponent, z: MPComponent, objects: objects::ObjectList) -> MPProblem {
        let time = 0.0;
        Self { time, x, y, z, objects }
    }
}

/// Initial (`*i`) and goal (`*g`) positions of all three axes for one solution.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MPCompPositions {
    xi: f64,
    xg: f64,
    yi: f64,
    yg: f64,
    zi: f64,
    zg: f64,
}
impl MPCompPositions {
    /// Creates a position record from per-axis initial/goal pairs.
    pub fn new(xi: f64, xg: f64, yi: f64, yg: f64, zi: f64, zg: f64) -> Self {
        Self { xi, xg, yi, yg, zi, zg }
    }

    /// Returns a 3x2 array grouped by component.
    ///
    /// Rows are x, y, z; columns are `[initial, goal]`.
    pub fn by_comp(&self) -> Array2<f64> {
        array![
            [self.xi, self.xg],
            [self.yi, self.yg],
            [self.zi, self.zg]
        ]
    }

    /// Returns a 2x3 array grouped by initial/goal.
    ///
    /// Row 0 holds the initial positions and row 1 the goal positions; columns are x, y, z.
    pub fn by_ig(&self) -> Array2<f64> {
        array![
            [self.xi, self.yi, self.zi],
            [self.xg, self.yg, self.zg]
        ]
    }
}

/// One solution: a combination of one sampled x, y and z component solution set.
#[derive(Clone, Debug, Copy)]
pub struct MPSolutionElem {
    pub x: ss::SolutionSetVectorElem,
    pub y: ss::SolutionSetVectorElem,
    pub z: ss::SolutionSetVectorElem
}
impl MPSolutionElem {
    /// Combines the solution sets of three component solutions.
    ///
    /// Only the `ss` field of each argument is kept; intersection data is not carried over.
    pub fn from_component_solution(
        x: &MPComponentSolutionElem,
        y: &MPComponentSolutionElem,
        z: &MPComponentSolutionElem,
    ) -> Self {
        // `SolutionSetVectorElem` must be `Copy` (this struct derives `Copy`),
        // so the fields are copied directly instead of calling `.clone()`.
        Self { x: x.ss, y: y.ss, z: z.ss }
    }

    /// Returns the initial/goal positions as a flat array:
    /// `[x.xi, x.xg, y.xi, y.xg, z.xi, z.xg]`.
    pub fn get_positions_as_array(&self) -> Array1<f64> {
        array![
            self.x.xi, self.x.xg,
            self.y.xi, self.y.xg,
            self.z.xi, self.z.xg
        ]
    }

    /// Returns the initial/goal positions of each axis as an [`MPCompPositions`].
    pub fn get_position(&self) -> MPCompPositions {
        MPCompPositions::new(
            self.x.xi, self.x.xg,
            self.y.xi, self.y.xg,
            self.z.xi, self.z.xg
        )
    }

    /// Returns the maximum time any component needs to reach its goal (largest `tg`).
    pub fn get_max_time(&self) -> f64 {
        self.x.tg.max(self.y.tg.max(self.z.tg))
    }

    /// Generates a 4PL trajectory from this solution.
    ///
    /// All three axes share the same end time, [`get_max_time`](Self::get_max_time).
    /// - `time_start`: the time at which the trajectory starts
    /// - `num_samples`: the number of samples to take of the trajectory, i.e. the resolution
    pub fn generate_fourpl_trajectory(&self, time_start: f64, num_samples: usize) -> trajectory::FourPLTrajectory {
        let time_goal = self.get_max_time();
        let x = trajectory::FourPLTrajectorySingle::from_solution(&self.x, time_start, time_goal, num_samples);
        let y = trajectory::FourPLTrajectorySingle::from_solution(&self.y, time_start, time_goal, num_samples);
        let z = trajectory::FourPLTrajectorySingle::from_solution(&self.z, time_start, time_goal, num_samples);
        trajectory::FourPLTrajectory { x, y, z }
    }
}

/// A container for all the [`MPSolutionElem`](MPSolutionElem)s.
///
/// `Default` yields an empty container, identical to [`MPSolution::new`].
#[derive(Clone, Debug, Default)]
pub struct MPSolution {
    pub solutions: Vec<MPSolutionElem>
}
impl MPSolution {
    /// Creates an empty container.
    pub fn new() -> Self {
        Self { solutions: Vec::new() }
    }

    /// Wraps an existing vector of solutions.
    pub fn from_vec(solutions: Vec<MPSolutionElem>) -> Self {
        Self { solutions }
    }

    /// Appends a solution.
    pub fn push(&mut self, sol: MPSolutionElem) {
        self.solutions.push(sol);
    }

    /// Returns every solution's positions as a `len() x 6` array.
    ///
    /// Each row comes from [`MPSolutionElem::get_positions_as_array`]:
    /// `[x.xi, x.xg, y.xi, y.xg, z.xi, z.xg]`.
    pub fn get_positions_as_array(&self) -> Array2<f64> {
        let mut a = Array::zeros((self.solutions.len(), 6));
        for (i, sol) in self.solutions.iter().enumerate() {
            a.row_mut(i).assign(&sol.get_positions_as_array());
        }
        a
    }

    /// Returns every solution's positions as per-axis [`MPCompPositions`].
    pub fn get_positions(&self) -> Vec<MPCompPositions> {
        self.solutions.iter().map(|s| s.get_position()).collect()
    }

    /// Returns, for each solution, the largest goal time across its x, y and z components.
    pub fn get_max_times(&self) -> Vec<f64> {
        self.solutions.iter().map(|s| s.get_max_time()).collect()
    }

    /// Number of solutions.
    pub fn len(&self) -> usize {
        self.solutions.len()
    }

    /// Returns `true` when the container holds no solutions.
    pub fn is_empty(&self) -> bool {
        self.solutions.is_empty()
    }

    /// Computes the trajectories of all solutions in one go.
    /// - `time_start`: the time at which the trajectory starts
    /// - `num_samples`: the number of samples to take of the trajectory, i.e. the resolution
    pub fn compute_trajectories(&self, time_start: f64, num_samples: usize) -> trajectory::FourPLTrajectories {
        let mut trajs = trajectory::FourPLTrajectories::new(num_samples);
        for sol in self.solutions.iter() {
            trajs.push(sol.generate_fourpl_trajectory(time_start, num_samples));
        }
        trajs
    }

    /// Computes trajectories for a random subset of the solutions.
    ///
    /// The subset is chosen with `rand`'s `IteratorRandom::choose_multiple`. If fewer than
    /// `num_trajectories` solutions exist, all of them are used. The order of the chosen
    /// solutions is not guaranteed to be random.
    /// - `time_start`: the time at which the trajectory starts
    /// - `num_samples`: the number of samples to take of the trajectory, i.e. the resolution
    /// - `num_trajectories`: the number of trajectories to be chosen
    pub fn compute_trajectories_random_sample(&self, time_start: f64, num_samples: usize, num_trajectories: usize) -> trajectory::FourPLTrajectories {
        let mut rng = thread_rng();
        let mut trajs = trajectory::FourPLTrajectories::new(num_samples);
        for sol in self.solutions.iter().choose_multiple(&mut rng, num_trajectories) {
            trajs.push(sol.generate_fourpl_trajectory(time_start, num_samples));
        }
        trajs
    }

    /// Consumes the container and returns an iterator that generates trajectories on demand.
    pub fn into_traj_iter(self, time_start: f64, num_samples: usize) -> trajectory::FourPLTrajectoryGeneratorIterator {
        trajectory::FourPLTrajectoryGeneratorIterator::new(self.solutions, time_start, num_samples)
    }
}
impl IntoIterator for MPSolution {
    type Item = MPSolutionElem;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    /// Yields the owned solutions in insertion order.
    fn into_iter(self) -> Self::IntoIter {
        self.solutions.into_iter()
    }
}

/// One axis component: its start state, its goal-position range and its derivative bounds.
/// Exposed to Python as `Component`.
#[pyclass(name="Component")]
#[derive(Debug, Clone)]
pub struct MPComponent {
    /// identifier, for reference
    pub id: String,
    /// starting position
    pub pos: f64,
    /// lower bound of the goal position
    pub pmin: f64,
    /// upper bound of the goal position
    pub pmax: f64,
    /// starting velocity
    pub vel: f64,
    /// lower velocity bound
    pub vmin: f64,
    /// upper velocity bound
    pub vmax: f64,
    /// starting acceleration
    pub accel: f64,
    /// lower acceleration bound
    pub amin: f64,
    /// upper acceleration bound
    pub amax: f64,
    /// lower jerk bound
    pub jmin: f64,
    /// upper jerk bound
    pub jmax: f64
}
impl MPComponent {
    /// Creates a component with the default derivative bounds:
    /// - velocity: `[-5.0, 5.0]`
    /// - acceleration: `[-10.0, 10.0]`
    /// - jerk: `[-50.0, 50.0]`
    pub fn new(id: String, pos: f64, pmin: f64, pmax: f64, vel: f64, accel: f64) -> MPComponent {
        MPComponent {
            id,
            pos,
            pmin,
            pmax,
            vel,
            vmin: -5.0,
            vmax: 5.0,
            accel,
            amin: -10.0,
            amax: 10.0,
            jmin: -50.0,
            jmax: 50.0,
        }
    }

    /// Overwrites the velocity bounds (no ordering check is performed).
    pub fn set_vrange(&mut self, vmin: f64, vmax: f64) {
        self.vmin = vmin;
        self.vmax = vmax;
    }

    /// Overwrites the acceleration bounds (no ordering check is performed).
    pub fn set_arange(&mut self, amin: f64, amax: f64) {
        self.amin = amin;
        self.amax = amax;
    }

    /// Overwrites the jerk bounds (no ordering check is performed).
    pub fn set_jrange(&mut self, jmin: f64, jmax: f64) {
        self.jmin = jmin;
        self.jmax = jmax;
    }
}

/// A single component's solution set together with its obstacle-intersection results.
/// Equivalent to an R cell element in the MATLAB implementation.
#[derive(Clone, Debug)]
pub struct MPComponentSolutionElem {
    /// A solution set which satisfies the velocity, acceleration and jerk bounds
    pub ss: ss::SolutionSetVectorElem,
    /// One intersection result per object, in object order
    pub intersects: Vec<objects::Intersection>,
    /// Summary of `intersects`:
    /// - `None`: every intersection is `None` (the solution misses all objects)
    /// - `Complete`: every intersection is `Complete` (the solution hits all objects)
    /// - `Partial`: anything else
    pub flag: objects::IntersectionFlag
}
impl MPComponentSolutionElem {
    /// Wraps `ss` without any obstacle checking.
    ///
    /// The result has no intersections and flag `Partial`, so `is_all_missed` returns `false`.
    pub fn from_ss(ss: ss::SolutionSetVectorElem) -> Self {
        Self {
            ss,
            intersects: vec![],
            flag: objects::IntersectionFlag::Partial,
        }
    }

    /// Checks `ss` against every object in `objaxis` at `time` and summarises the result.
    ///
    /// One [`Intersection`](objects::Intersection) is recorded per object.
    /// `flag` becomes `None` if every intersection is `None`, which is also the case when
    /// there are no objects. It becomes `Complete` if every intersection is `Complete`,
    /// and `Partial` otherwise.
    fn from_ss_with_object_avoidance(time: f64, ss: ss::SolutionSetVectorElem, objaxis: &objects::ObjectAxisList) -> Self {
        log::debug!("ss: {:?}", ss);
        let intersects: Vec<objects::Intersection> = objaxis
            .objects
            .iter()
            .map(|obj| ss.compute_object_intersections(obj, time))
            .collect();
        let missed_all = intersects
            .iter()
            .all(|hit| hit.flag == objects::IntersectionFlag::None);
        let flag = if missed_all {
            objects::IntersectionFlag::None
        } else if intersects.iter().all(|hit| hit.flag == objects::IntersectionFlag::Complete) {
            objects::IntersectionFlag::Complete
        } else {
            objects::IntersectionFlag::Partial
        };
        Self { ss, intersects, flag }
    }

    /// `true` when the summary flag says every object was missed.
    pub fn is_all_missed(&self) -> bool {
        self.flag == objects::IntersectionFlag::None
    }

    /// Returns the per-object intersection flags, in object order.
    pub fn get_all_flags(&self) -> Vec<objects::IntersectionFlag> {
        self.intersects.iter().map(|hit| hit.flag).collect()
    }

    /// Returns the per-object flags as numbers (via `to_num`) for the heuristic method.
    pub fn get_all_flags_as_array(&self) -> Array1<u8> {
        let codes: Vec<u8> = self.intersects.iter().map(|hit| hit.flag.to_num()).collect();
        Array::from_vec(codes)
    }
}

/// A container for the individual solutions of a given axis component [`MPComponentSolutionElem`](MPComponentSolutionElem)
///
/// `Default` yields an empty container, identical to [`MPComponentSolution::new`].
#[derive(Clone, Debug, Default)]
pub struct MPComponentSolution {
    pub solutions: Vec<MPComponentSolutionElem>
}

impl MPComponentSolution {
    /// Creates an empty container.
    pub fn new() -> Self {
        Self { solutions: Vec::new() }
    }

    /// Wraps each solution set with [`MPComponentSolutionElem::from_ss`]; no obstacle checks are done.
    pub fn from_ss(ssv: ss::SolutionSetVectors) -> Self {
        Self {
            solutions: ssv.into_iter().map(MPComponentSolutionElem::from_ss).collect(),
        }
    }

    /// Wraps each solution set after checking it against every object in `objectaxis` at `time`.
    pub fn from_ss_with_object_avoidance(time: f64, ssv: ss::SolutionSetVectors, objectaxis: &objects::ObjectAxisList) -> Self {
        Self {
            solutions: ssv
                .into_iter()
                .map(|ss| MPComponentSolutionElem::from_ss_with_object_avoidance(time, ss, objectaxis))
                .collect(),
        }
    }

    /// `true` if at least one solution misses every object.
    pub fn check_any_missed(&self) -> bool {
        self.solutions.iter().any(|s| s.is_all_missed())
    }

    /// Returns a new container holding clones of only the solutions that miss every object.
    pub fn get_all_missed_solutions(&self) -> MPComponentSolution {
        MPComponentSolution {
            solutions: self.solutions.iter().filter(|s| s.is_all_missed()).cloned().collect()
        }
    }

    /// Returns a `solutions.len() x 3` array with one row per solution.
    ///
    /// Each row is intended to hold the start position, end position and end time.
    /// NOTE(review): rows are filled from `ss.to_xg_xi_tg()`, whose name suggests
    /// `[xg, xi, tg]` column order rather than `[xi, xg, tg]`. Confirm against `ss`.
    pub fn into_xi_xg_tg_array(&self) -> Array2<f64> {
        let mut a = Array::zeros((self.solutions.len(), 3));
        for (i, s) in self.solutions.iter().enumerate() {
            a.row_mut(i).assign(&s.ss.to_xg_xi_tg());
        }
        a
    }

    /// Returns the intersection flags as a `solutions x objects` array.
    ///
    /// Returns an empty `0 x 0` array when there are no solutions. The column count is taken
    /// from the first solution, so all solutions are assumed to have the same number of
    /// intersections.
    pub fn get_flags_as_2d_array(&self) -> Array2<u8> {
        if self.solutions.is_empty() {
            return Array::zeros((0, 0));
        }
        let mut a = Array::zeros((self.solutions.len(), self.solutions[0].intersects.len()));
        for (i, s) in self.solutions.iter().enumerate() {
            a.row_mut(i).assign(&s.get_all_flags_as_array());
        }
        a
    }

    /// Shuffles the solutions in place using the thread-local RNG.
    pub fn random_shuffle(&mut self) {
        self.solutions.shuffle(&mut thread_rng());
    }
}

/// Bundles the per-axis component solutions for x, y and z.
#[derive(Clone, Debug)]
pub struct MPComponentSolutions {
    pub xsol: MPComponentSolution,
    pub ysol: MPComponentSolution,
    pub zsol: MPComponentSolution,
}
impl MPComponentSolutions {
    /// Groups three per-axis solution containers.
    pub fn new(xsol: MPComponentSolution, ysol: MPComponentSolution, zsol: MPComponentSolution) -> Self {
        Self { xsol, ysol, zsol }
    }

    /// Shuffles each axis's solutions independently and in place, in x, y, z order.
    pub fn random_shuffle(&mut self) {
        [&mut self.xsol, &mut self.ysol, &mut self.zsol]
            .iter_mut()
            .for_each(|axis| axis.random_shuffle());
    }
}