/*
MIT License

Copyright (c) 2021 University of Bristol Flight Laboratory

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Author: Mickey Li <mickeyhli@outlook.com>
*/

//! Module containing structures and algorithms for computing the solution set for a given axis component.
use ndarray::{Array, Array1, Array3, Array2, Axis, stack, array};
use num::complex::Complex64;
use itertools::{Itertools, izip};
use ordered_float::{OrderedFloat};
use crate::mp::{MPComponent, MPParams, objects};

/// Builder for the core elements of the Component Solution Set.
///
/// First calculates `xgs` and `xis`, which cover all combinations of initial and final positions
/// given the grid sample parameter `dt`.
/// The B and C matrices are then calculated for each of those combinations.
/// The start and end trajectory times (`td`, `tg`) are also calculated for each of those combinations.
///
/// Recommended usage, which computes all of the parameters before building:
/// ```ignore
/// let builder = ComponentSolutionSetBuilder::new(time, comp, params).calculate_parameters();
/// builder.build()
/// ```
struct ComponentSolutionSetBuilder<'a> {
    /// Search time the solution set is built for.
    time: f64,
    /// Component the solution set is built for.
    comp: &'a MPComponent,
    /// Number of grid samples per axis (taken from `MPParams::dt`).
    dt: usize,
    /// Tolerance used in the `tg` computation (taken from `MPParams::epsilon`).
    epsilon: f64,
    /// `xg` grid, shape `(dt, dt)`; filled by `calculate_xgs`.
    xgs: Option<Array2<f64>>,
    /// Candidate `xi` values, shape `(dt, dt)`; filled by `calculate_xis`.
    xis: Option<Array2<f64>>,
    /// B matrix; filled by `calculate_b_matrix`.
    b_mat: Option<Array2<f64>>,
    /// C matrix; filled by `calculate_c_matrix`.
    c_mat: Option<Array2<f64>>,
    /// `td` matrix; filled by `calculate_td`.
    td: Option<Array2<f64>>,
    /// `tg` matrix; filled by `calculate_tg`.
    tg: Option<Array2<f64>>,
}
impl<'a> ComponentSolutionSetBuilder<'a> {
    pub fn new(time: f64, comp: &'a MPComponent, params: &MPParams) -> Self{
        Self {
            time, comp,
            dt: params.dt,
            epsilon: params.epsilon,
            xgs:  None,
            xis:  None,
            b_mat:None,
            c_mat:None,
            td:   None,
            tg:   None
        }
    }

    pub fn build(self) -> ComponentSolutionSet<'a> {
        ComponentSolutionSet {
            comp: self.comp,
            time : self.time,
            xgs   : self.xgs.unwrap(),
            xis   : self.xis.unwrap(),
            b_mat : self.b_mat.unwrap(),
            c_mat : self.c_mat.unwrap(),
            td    : self.td.unwrap(),
            tg    : self.tg.unwrap(),
        }
    }

    fn calculate_parameters(self) -> Self{
        self.calculate_xis()
        .calculate_xgs()
        .calculate_b_matrix()
        .calculate_c_matrix()
        .calculate_td()
        .calculate_tg()
    }

    fn calculate_xis(mut self) -> Self {
        let comp = &self.comp; let dt = self.dt;
        // X range bottom 1D array dt
        let xgs_1d = Array::linspace(comp.pmin, comp.pmax, dt);

        // Max x's such that B > 2, 1D array dt
        let ximax =
        ((comp.vel.powf(2.0) * (8.0 * comp.pos - 3.0 * &xgs_1d))
            + (4.0 * comp.accel * comp.pos * (&xgs_1d - comp.pos)))
        / (5.0 * comp.vel.powf(2.0)
            - (4.0 * comp.accel * (comp.pos - &xgs_1d)));
        let ximin =
        ((comp.accel * comp.pos * (&xgs_1d - comp.pos))
            - (&xgs_1d * comp.vel.powf(2.0))
            + (2.0 * comp.pos * comp.vel.powf(2.0)))
        / ( comp.vel.powf(2.0) + (comp.accel * (&xgs_1d - comp.pos)));

        // println!("ximax: {}\nximin: {}", &ximax, &ximin);

        let mut xi = Array::<f64, _>::zeros((dt, dt));
        for i in 0..dt {
            let k = Array::linspace(ximin[i]+0.01, ximax[i], dt);
            xi.row_mut(i).assign(&k);
        }
        self.xis = Some(xi);
        self
    }

    fn calculate_xgs(mut self) -> Self {
        let xis = self.xis.as_ref().unwrap();
        // X range bottom 1D array dt
        let xgs_1d = Array::linspace(self.comp.pmin, self.comp.pmax, self.dt);

        // Make xgs same shape as xis. Take ownership of xgs_1d
        let xgs = xgs_1d.insert_axis(Axis(1))
            * Array::<f64, _>::ones(xis.raw_dim());
        self.xgs = Some(xgs);
        self
    }

    /// Calculate B matrix, dt x dt
    /// Bx = (vs.^2 .* (xg-xi)) ./ ((2 .* vs.^2 .* (xg + xi - 2.*xs)) - (2 .* as .* (xg-xs).*(xs-xi)));
    fn calculate_b_matrix(mut self) -> Self {
        let comp = &self.comp;
        let xgs = self.xgs.as_ref().unwrap();
        let xis = self.xis.as_ref().unwrap();
        let b = (comp.vel.powf(2.0) * (xgs - xis))
            / ((2.0 * comp.vel.powf(2.0) * (xgs + xis - 2.0 * comp.pos))
                - (2.0 * comp.accel * (xgs - comp.pos) * (comp.pos - xis)));
        self.b_mat = Some(b);
        self
    }

    /// Calculate C matrix, dt x dt
    /// Cx = abs(2.*Bx.*(xg - xs).*(xs-xi).*(((xg - xs)./(xs-xi)).^(1./(2.*Bx)))) ./ (vs .* (xg-xi));
    fn calculate_c_matrix(mut self) -> Self {
        let comp = &self.comp;
        let xgs = self.xgs.as_ref().unwrap();
        let xis = self.xis.as_ref().unwrap();
        let b_mat = self.b_mat.as_ref().unwrap();
        let mut c_num_1 = (xgs - comp.pos) / (comp.pos - xis);
        c_num_1.zip_mut_with(b_mat, |g, b| {*g = g.powf(1.0 / (2.0 * b)); } );
        let c_num = 2.0 * b_mat * (xgs - comp.pos)
                        * (comp.pos - xis) * c_num_1;
        let c = c_num.mapv(f64::abs) / (comp.vel * (xgs - xis));
        self.c_mat = Some(c);
        self
    }

    /// Calculate td, dt x dt
    /// td = (Cx .* (((xs-xi)./(xg-xs)).^(1./(2.*Bx))));
    fn calculate_td(mut self) -> Self {
        let comp = &self.comp;
        let xgs = self.xgs.as_ref().unwrap();
        let xis = self.xis.as_ref().unwrap();
        let b_mat = self.b_mat.as_ref().unwrap();
        let c_mat = self.c_mat.as_ref().unwrap();
        let mut td_temp = (comp.pos - xis) / (xgs - comp.pos);
        td_temp.zip_mut_with(b_mat, |g, b| {*g = g.powf(1.0 / (2.0 * b)); });
        self.td = Some(c_mat * td_temp);
        self
    }

    /// Calculate tg, dt x dt
    /// tg = ts - td + (abs(Cx) .* ((abs(xg-xi)-eps)./eps).^(1./(2.*Bx)));
    fn calculate_tg(mut self) -> Self {
        let xgs = self.xgs.as_ref().unwrap();
        let xis = self.xis.as_ref().unwrap();
        let b_mat = self.b_mat.as_ref().unwrap();
        let c_mat = self.c_mat.as_ref().unwrap();
        let td = self.td.as_ref().unwrap();
        let mut tg_temp = ((xgs - xis).mapv(f64::abs) - self.epsilon) / self.epsilon;
        tg_temp.zip_mut_with(b_mat, |g, b| {*g = g.powf(1.0 / (2.0 * b)); });
        self.tg = Some(self.time - td + (c_mat.mapv(f64::abs) * tg_temp));
        self
    }
}

/// A constructed solution set based on a given component.
///
/// Every matrix field has shape `(dt, dt)`. Row `i` belongs to the `i`-th `xg` grid sample, and
/// the columns sweep the candidate `xi` values generated for that row.
pub struct ComponentSolutionSet<'a> {
    /// Component this solution set was generated for.
    comp: &'a MPComponent,
    /// Search time the set was constructed at.
    time: f64,
    /// `xg` grid (`dt` samples of `pmin..=pmax`, repeated along each row).
    xgs: Array2<f64>,
    /// Candidate `xi` values for each row.
    xis: Array2<f64>,
    /// B coefficient for each `(xg, xi)` pair.
    b_mat: Array2<f64>,
    /// C coefficient for each `(xg, xi)` pair.
    c_mat: Array2<f64>,
    /// `td` time parameter for each pair.
    td: Array2<f64>,
    /// `tg` time for each pair.
    tg: Array2<f64>,
}
impl<'a> ComponentSolutionSet<'a> {
    /// Constructs the solution set search space given the time at search, the component and the solver parameters.
    ///
    /// All matrices (`xgs`, `xis`, `b_mat`, `c_mat`, `td`, `tg`) are computed eagerly by the
    /// internal builder.
    pub fn new(time: f64, comp: &'a MPComponent, params: &MPParams) -> ComponentSolutionSet<'a>{
        // calculate and assign xis, xgs, b_mat, c_mat, td, tg for a given set of parameters
        let builder = ComponentSolutionSetBuilder::new(time, comp, params).calculate_parameters();
        builder.build()
    }

    /// Generates a unique list of solutions which produce feasible trajectories and satisfy the bounds on velocity, acceleration and jerk.
    ///
    /// First calculates the velocity, acceleration and jerk bounds for the given solution set,
    /// then selects the `(xg, xi)` pairs for which:
    /// * `2 <= B < 100`, and
    /// * every candidate velocity, acceleration and jerk value lies strictly inside the
    ///   component's `(vmin, vmax)`, `(amin, amax)` and `(jmin, jmax)` ranges.
    ///
    /// The selected solutions are de-duplicated on `(xi, xg)` by `SolutionSetVectors::unique`.
    pub fn generate_feasible_solutions(&self) -> SolutionSetVectors{
        let vmax = self.calculate_velocity_bounds();
        let amax = self.calculate_accel_bounds();
        let jmax = self.calculate_jerk_bounds();

        // Keep only solutions whose B lies in [2, 100).
        let mask1 = &self.b_mat.mapv(|x|x>= 2.) & &self.b_mat.mapv(|x|x < 100.);
        // For each bound, test every candidate and AND-reduce over axis 0, so a (xg, xi) pair
        // passes only if all of its candidates are within limits. NaN candidates fail the
        // comparisons and therefore reject the pair.
        let mask2 = vmax.mapv(|x| x < self.comp.vmax && x > self.comp.vmin)
                    .outer_iter().fold(Array::from_elem(mask1.raw_dim(), true), |a, b| a & b);
        let mask3 = amax.mapv(|x| x < self.comp.amax && x > self.comp.amin)
                    .outer_iter().fold(Array::from_elem(mask1.raw_dim(), true), |a, b| a & b);
        let mask4 = jmax.mapv(|x| x < self.comp.jmax && x > self.comp.jmin)
                    .outer_iter().fold(Array::from_elem(mask1.raw_dim(), true), |a, b| a & b);
        let mask = mask1 & mask2 & mask3 & mask4;

        // Flatten each matrix to the entries selected by `mask`. All six use the same iteration
        // order, so index k of every vector refers to the same solution.
        let xg = filter_by_mask(&self.xgs, &mask);
        let xi = filter_by_mask(&self.xis, &mask);
        let bs = filter_by_mask(&self.b_mat, &mask);
        let cs = filter_by_mask(&self.c_mat, &mask);
        let td = filter_by_mask(&self.td, &mask);
        let tg = filter_by_mask(&self.tg, &mask);

        SolutionSetVectors {
            vecs: izip!(xi, xg, bs, cs, td, tg).map(|x|
                    SolutionSetVectorElem {
                        xi: x.0,
                        xg: x.1,
                        bs: x.2,
                        cs: x.3,
                        td: x.4,
                        tg: x.5
                    }).collect()
        }.unique()
    }

    /// Calculate bounds on velocity; output shape is `(2, dt, dt)`.
    ///
    /// Axis 0 holds the `+` and `-` branches of the candidate extremum times `tv` (stacked along
    /// `Axis(0)`). Entries whose candidate time is before `self.time` are set to `0.0`.
    /// See paper and SS_1D_19feb2021_v2.m: 80ish for more details
    fn calculate_velocity_bounds(&self) -> Array3<f64>{
        // +1 / -1 sign planes, stacked to shape (2, dt, dt).
        let plus = Array::<f64, _>::ones(self.xgs.raw_dim());
        let minus = -1.0 * Array::<f64, _>::ones(self.xgs.raw_dim());
        let pm = stack(Axis(0), &[plus.view(), minus.view()]).unwrap();

        // Velocity Bounds
        // Candidate times: tv = ts - td ± C * ((2B - 1) / (2B + 1))^(1 / (2B)).
        let mut tv_temp1 = (2.0 * &self.b_mat - 1.0)/(2.0 * &self.b_mat + 1.0);
        tv_temp1.zip_mut_with(&self.b_mat, |g, b| {*g = g.powf(1.0 / (2.0 * b)); });
        let tv_temp2 = &self.c_mat * tv_temp1;
        let tv = self.time - &self.td + pm * (tv_temp2);

        // ((tv - ts + td) / C)^(2B), computed as (x^2)^B with b_mat applied to each axis-0 plane.
        let mut vmax_temp11 = (&tv - self.time + &self.td) / &self.c_mat;
        vmax_temp11.outer_iter_mut().for_each(|mut av| {
                    av.zip_mut_with(&self.b_mat, |g, &b| {
                        *g = g.powf(2.0);
                        *g = g.powf(b);
                    });
                }); // confirmed working
        let mut vmax = (2.0 * &self.b_mat * (&self.xgs - &self.xis) * &vmax_temp11)
                / ((&tv - self.time + &self.td) * (1.0 + &vmax_temp11).mapv(|x| x.powf(2.0)));
        // Candidates occurring before the search time are zeroed.
        vmax.zip_mut_with(&tv, |v, &_tv| {if _tv < self.time {*v = 0.0;}});

        log::debug!("vmax shape: {:?}\n{}", &vmax.shape(), &vmax); //confirmed matching
        vmax
    }

    /// Calculate bounds on acceleration; output shape is `(4, dt, dt)`.
    ///
    /// Axis 0 holds the four sign combinations `(pm1, pm2)` = `(+,+)`, `(-,+)`, `(+,-)`, `(-,-)`
    /// used to build the candidate extremum times `ta`. Entries whose candidate time is before
    /// `self.time` are set to `0.0`.
    /// See paper and SS_1D_19feb2021_v2.m: 80ish for more details
    fn calculate_accel_bounds(&self) -> Array3<f64>{
        let b_mat = &self.b_mat;
        let c_mat = &self.c_mat;
        let plus = Array::<f64, _>::ones(self.xgs.raw_dim());
        let minus = -1.0 * Array::<f64, _>::ones(self.xgs.raw_dim());
        // Acceleration Bounds
        // Coefficients of the quadratic f1*y^2 + f2*y + f3 = 0 solved below.
        let f1 = (4.0 * b_mat + 2.0) * (b_mat + 1.0);
        let f2 = 4.0 - (16.0 * b_mat.mapv(|x| x.powf(2.0)));
        let f3 = 2.0 * (b_mat - 1.0) * (2.0 * b_mat - 1.0);
        // pm1 selects the root (± discriminant); pm2 selects the sign of the time offset.
        let pm1 = stack(Axis(0), &[plus.view(), minus.view(),
                        plus.view(), minus.view()]).unwrap();
        let pm2 = stack(Axis(0), &[plus.view(), plus.view(),
                        minus.view(), minus.view()]).unwrap();

        // Quadratic formula; a negative discriminant yields NaN from sqrt.
        let ta_inner = (&f2.mapv(|x| x.powf(2.0))
                                - (4.0 * &f1 * &f3)).mapv(f64::sqrt);
        let ta_inner2 = (-&f2 + pm1 * ta_inner)/(2.0 * &f1);
        let ta = self.time - &self.td + pm2 * c_mat * pow_1div2b(ta_inner2, b_mat);

        // ((ta - ts + td) / C)^(2B), computed as (x^2)^B per axis-0 plane.
        let mut amax_inner = (&ta - self.time + &self.td) / c_mat;
        amax_inner.outer_iter_mut().for_each(|mut av| {
            av.zip_mut_with(b_mat, |g, &b| {
                *g = g.powf(2.0);
                *g = g.powf(b);
            })
        });
        let amax_outer_1 = (2.0 * b_mat * (&self.xgs - &self.xis) * &amax_inner)
            / ((&ta - self.time + &self.td).mapv(|x|x.powf(2.0))
                * (1.0 + &amax_inner).mapv(|x|x.powf(2.0)));
        let amax_outer_2 = (2.0 * b_mat - 1.0)
            - (4.0 * b_mat * &amax_inner) / (1.0 + &amax_inner);
        let mut amax = amax_outer_1 * amax_outer_2;
        // Candidates occurring before the search time are zeroed.
        amax.zip_mut_with(&ta, |a, &_ta| {if _ta < self.time {*a = 0.0;}});

        log::debug!("amax shape: {:?}\n{}", &amax.shape(), &amax); //confirmed matching
        amax
    }

    /// Calculate bounds on jerk; output shape is `(7, dt, dt)`.
    ///
    /// Axis 0 holds six candidate extremum times plus the search time `self.time` itself. The six
    /// come from the three roots of a cubic, each combined with a `+` and a `-` time offset.
    /// Entries whose candidate time is before `self.time` are set to `0.0`.
    /// See paper and SS_1D_19feb2021_v2.m: 80ish for more details
    fn calculate_jerk_bounds(&self) -> Array3<f64>{
        let b_mat = &self.b_mat;
        let c_mat = &self.c_mat;
        let td = &self.td;
        let plus = Array::<f64, _>::ones(self.xgs.raw_dim());
        let minus = -1.0 * Array::<f64, _>::ones(self.xgs.raw_dim());

        // Jerk Bounds
        // Coefficients of the cubic aj*y^3 + bj*y^2 + cj*y + dj = 0.
        let bj1 = 96.0 * pow(b_mat, 3.);
        let bj2 = 72.0 * pow(b_mat , 2.) * (2.0 * b_mat - 1.);
        let bj3 = b_mat * (2.0 * b_mat - 1.0) * (28.0 * b_mat - 22.0);
        let bj4 = (2.0 * b_mat - 3.0) * (2.0 * b_mat - 1.0) * (b_mat - 1.0);
        let aj = &bj1 - &bj2 + &bj3 - &bj4;
        let bj = -&bj2 + 2.0 * &bj3 - 3.0 * &bj4;
        let cj = &bj3 - 3.0*&bj4;
        let dj = -1.0 * &bj4;

        // Cardano's formula: Q, R, then S and T (principal complex cube roots) below.
        let qj = ((3. * &aj * &cj) - pow(&bj, 2.))
                / (9. * pow(&aj, 2.0));
        let rj = ((9. * &aj * &bj * &cj)
                    - (27. * pow(&aj, 2.) * &dj)
                    - (2. * pow(&bj, 3.)))
                / (54. * pow(&aj, 3.));
        let qjc = qj.mapv(Complex64::from); // next few computations known to go complex
        let rjc = rj.mapv(Complex64::from); // next few computations known to go complex
        let sjc = (&rjc + (&rjc.mapv(|x|x.powf(2.0)) + &qjc.mapv(|x| x.powf(3.0))).mapv(|x| x.sqrt())).mapv(|x|x.powf(1./3.));
        let vjc = (&rjc - (&rjc.mapv(|x|x.powf(2.0)) + &qjc.mapv(|x| x.powf(3.0))).mapv(|x| x.sqrt())).mapv(|x|x.powf(1./3.));

        // Complex sign planes: pmc for the first root, and (pm1c, pm2c) for the two remaining
        // roots (± i*sqrt(3) term) combined with ± time offsets.
        let pmc = stack(Axis(0), &[plus.view(), minus.view()]).unwrap().mapv(Complex64::from);
        let pm1c = stack(Axis(0), &[plus.view(), minus.view(),
                        plus.view(), minus.view()]).unwrap().mapv(Complex64::from);
        let pm2c = stack(Axis(0), &[plus.view(), plus.view(),
                        minus.view(), minus.view()]).unwrap().mapv(Complex64::from);

        let c_matc = c_mat.mapv(Complex64::from);
        let tdc = td.mapv(Complex64::from);
        let timec = Complex64::from(self.time);

        // First root: S + T - bj / (3 aj); times ts - td ± C * root^(1/(2B)) -> shape (2, dt, dt).
        let mut tj14_tempc = (&sjc + &vjc - (&bj/(3. * &aj))).mapv(Complex64::from);
        tj14_tempc.zip_mut_with(b_mat, |g, b| {*g = g.powf(1.0 / (2.0 * b)); });
        let tj14c = timec - &tdc + pmc * &c_matc  * tj14_tempc;

        // Remaining roots: -(S + T)/2 - bj/(3 aj) ± i*sqrt(3)*(S - T)/2 -> shape (4, dt, dt).
        let mut tj2356_tempc = -((&sjc + &vjc)/2.) - (&bj / (3.*&aj)) + pm1c * (Complex64::from(-3.0).sqrt() * (&sjc - &vjc)/2.);
        tj2356_tempc.zip_mut_with(b_mat, |g, b| {*g = g.powf(1.0 / (2.0 * b)); });
        let tj2356c = timec - &tdc + pm2c * &c_matc * tj2356_tempc;
        let tj14 = tj14c.mapv(|x| x.re); // go back to f64 real
        let tj2356 = tj2356c.mapv(|x| x.re); // go back to f64 real

        // Need to be included to remove |J(ts)| > j_bnd values, that is intial jerk value violating bound
        let tjs = Array::from_elem(b_mat.raw_dim(), self.time);

        // Seven candidate times, shape (7, dt, dt).
        let tj = stack!(Axis(0),
            tj14.index_axis(Axis(0), 0),
            tj2356.index_axis(Axis(0), 0),
            tj2356.index_axis(Axis(0), 1),
            tj14.index_axis(Axis(0), 1),
            tj2356.index_axis(Axis(0), 2),
            tj2356.index_axis(Axis(0), 3),
            tjs
        ); // confirmed

        // println!("tj shape: {:?}\n{}", &tj.shape(), &tj);

        // ((tj - ts + td) / C)^(2B), computed as (x^2)^B per axis-0 plane.
        let mut jterm1_temp = (&tj - self.time + td)/c_mat ;
        jterm1_temp.outer_iter_mut().for_each(|mut av| {
            av.zip_mut_with(b_mat, |g, &b| {
                *g = g.powf(2.0);
                *g = g.powf(b);
            })
        });
        let jterm1 = (2. * b_mat * (&self.xgs - &self.xis) * &jterm1_temp)
                / (pow(&(&tj - self.time + td), 3.) * pow(&(1. + &jterm1_temp), 2.));
        // ((tj - ts + td) / C)^(4B), computed as (x^4)^B per axis-0 plane.
        let mut jterm2_temp = (&tj - self.time + td)/c_mat;
        jterm2_temp.outer_iter_mut().for_each(|mut av| {
            av.zip_mut_with(b_mat, |g, &b| {
                *g = g.powf(4.0);
                *g = g.powf(b);
            })
        });
        let jterm2 = (24. * pow(b_mat, 2.0) * &jterm2_temp)
                /  pow(&(1. + &jterm1_temp), 2.);
        let jterm3 = (12. * b_mat * (2. * b_mat - 1.) * &jterm1_temp)
                /  (1. + &jterm1_temp);
        let jterm4 = (2.* b_mat-1.) * (2.*b_mat - 2.);
        let mut jmax = jterm1 * (jterm2 - jterm3 + jterm4);
        // Candidates occurring before the search time are zeroed.
        jmax.zip_mut_with(&tj, |j, &_tj| {if _tj < self.time {*j = 0.0;}});

        log::debug!("jmax shape: {:?}\n{}", &jmax.shape(), &jmax);
        jmax
    }
}

/// Representation of a single component solution: one `(xg, xi)` pair from a
/// `ComponentSolutionSet` together with its B, C, `td` and `tg` values.
#[derive(Clone, Copy, Debug)]
pub struct SolutionSetVectorElem {
    /// Entry of the `xis` grid.
    pub xi:  f64,
    /// Entry of the `xgs` grid.
    pub xg:  f64,
    /// B coefficient for this pair.
    pub bs:  f64,
    /// C coefficient for this pair.
    pub cs:  f64,
    /// `td` time parameter for this pair.
    pub td:  f64,
    /// `tg` time for this pair (`ts - td + |C| * ((|xg - xi| - eps) / eps)^(1/(2B))`).
    pub tg:  f64,
}
impl SolutionSetVectorElem {
    fn to_compare(self) -> (OrderedFloat<f64>, OrderedFloat<f64>) {
        (OrderedFloat(self.xi), OrderedFloat(self.xg))
    }

    /// Returns an array only containing xi, xg and tg
    pub fn to_xg_xi_tg(self) -> Array1<f64> {
        array![self.xi, self.xg, self.tg]
    }

    pub fn to_vec(self) -> Vec<f64> {
        vec![self.xi, self.xg, self.bs, self.cs, self.td, self.tg]
    }

    /// Given an object, computes the intersection with itself
    /// See the paper and [`Intersection`](objects::Intersection) for more details.
    /// Essentially it performs a number of tests to determine which of type of intersection we get.
    pub fn compute_object_intersections(&self, obj: &objects::ObjectAxis, time: f64) -> objects::Intersection {
        let obs_pos = array![obj.bounds.l, obj.bounds.u];
        let calc_inner1_c = ((&obs_pos - self.xi) / (self.xg - &obs_pos)).mapv(Complex64::from);
        log::debug!("calc_inner1_c 1: {}", calc_inner1_c);
        let calc_inner2_c = Complex64::from(self.cs) * calc_inner1_c.mapv(|x| {
            log::debug!("calc_inner2_c x 1: {}", x);
            if x.re.is_infinite() {
                x // case where x was infinity, when square rooted it becomes Nan rather than Inf
            } else {
                x.powf(1.0/ (2.*self.bs))
            }
        });
        log::debug!("calc_inner2_c 1: {}", calc_inner2_c);
        let to_minus_c = time - Complex64::from(self.td) - &calc_inner2_c; // [l, u]
        let to_plus_c = time - Complex64::from(self.td) + &calc_inner2_c; // [l, u]
        let to = array![to_minus_c[1], to_minus_c[0], to_plus_c[0], to_plus_c[1]];

        let test1 = to.iter().all(|&x| x.im.abs() > 1e-5);
        let test2 = to[0].im.abs() < 1e-5 && to[1].im.abs() > 1e-5 && to[2].im.abs() > 1e-5 && to[3].im.abs() < 1e-5;
        let test3 = to[0].im.abs() > 1e-5 && to[1].im.abs() < 1e-5 && to[2].im.abs() < 1e-5 && to[3].im.abs() > 1e-5;

        let mut intersect = objects::Intersection::new();

        if test1 {
            let min_obs_pos = obs_pos[0].min(obs_pos[1]);
            let max_obs_pos = obs_pos[0].max(obs_pos[1]);
            let max_ss_x = self.xg.max(self.xi);
            let min_ss_x = self.xg.min(self.xi);
            if (min_obs_pos >= max_ss_x) || (max_obs_pos <= min_ss_x) {
                intersect.flag = objects::IntersectionFlag::None;
            } else if (min_obs_pos < min_ss_x) && (max_obs_pos > max_ss_x) {
                intersect.points.tkr_u = f64::INFINITY;
                intersect.flag = objects::IntersectionFlag::Complete;
            } else {
                panic!("All to imaginary but timelims exceeded");
            }
        } else if test2 { // Assumed 0 and 3 are real
            let max_to = to[0].re.max(to[3].re);
            let min_to = to[0].re.min(to[3].re);
            if max_to < time {
                intersect.flag = objects::IntersectionFlag::None;
            } else if max_to >= time && min_to <= time {
                intersect.points.tkr_u = time;
                intersect.points.tkr = max_to;
                intersect.flag = objects::IntersectionFlag::Partial;
            } else if min_to > time {
                intersect.points.tkr_u = min_to;
                intersect.points.tkr = max_to;
                intersect.flag = objects::IntersectionFlag::Partial;
            } else {
                panic!("Test2 checked but all missed, should not get here");
            }
        } else if test3 { // Assumed 1 and 2 are real
            let max_to = to[1].re.max(to[2].re);
            let min_to = to[1].re.min(to[2].re);
            if min_to > time {
                intersect.points.tkl = time;
                intersect.points.tkl_u = min_to;
                intersect.points.tkr_u = max_to;
                intersect.points.tkr = f64::INFINITY;
                intersect.flag = objects::IntersectionFlag::Partial;
            } else if max_to < time {
                intersect.points.tkr_u = time;
                intersect.points.tkr = f64::INFINITY;
                intersect.flag = objects::IntersectionFlag::Partial;
            } else {
                intersect.points.tkr_u = max_to; intersect.points.tkr = f64::INFINITY;
                intersect.flag = objects::IntersectionFlag::Partial;
            }
        } else { // All are real
            let mut to_ord = to.mapv(|x|x.re).into_raw_vec();
            to_ord.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less));
            if time < to_ord[0] {
                intersect.points.set_arr(to_ord);
                intersect.flag = objects::IntersectionFlag::Partial;
            } else if (time >= to_ord[0]) && (time <= to_ord[1]) {
                to_ord[0] = time;
                intersect.points.set_arr(to_ord);
                intersect.flag = objects::IntersectionFlag::Partial;
            } else if (time >= to_ord[1]) && (time <= to_ord[2]) {
                intersect.points.tkr_u = to_ord[2];
                intersect.points.tkr = to_ord[3];
                intersect.flag = objects::IntersectionFlag::Partial;
            } else if (time >= to_ord[2]) && (time <= to_ord[3]) {
                intersect.points.tkr_u = time;
                intersect.points.tkr = to_ord[3];
                intersect.flag = objects::IntersectionFlag::Partial;
            } else if time > to_ord[3] {
                intersect.flag = objects::IntersectionFlag::Complete
            } else {
                panic!("This statement should not be reached");
            }
        }
        intersect
    }
}

/// A container for the [`SolutionSetVectorElem`](SolutionSetVectorElem)s produced by
/// `ComponentSolutionSet::generate_feasible_solutions`; consume it with `into_iter()`.
#[derive(Debug)]
pub struct SolutionSetVectors {
    /// The stored solutions.
    vecs: Vec<SolutionSetVectorElem>
}

impl SolutionSetVectors {
    pub fn unique(mut self) -> Self {
        self.vecs = self.vecs.into_iter().unique_by(|x| x.to_compare()).collect();
        log::debug!("{:?}", &self.vecs);
        self
    }
}

impl IntoIterator for SolutionSetVectors {
    type Item = SolutionSetVectorElem;
    type IntoIter = std::vec::IntoIter<Self::Item>;
    fn into_iter(self) -> Self::IntoIter {
        self.vecs.into_iter()
    }
}

//////////////////////
// Helper functions //
//////////////////////

/// Raises every element of `mat` to the power `1 / (2 * b)`, where `b` is the corresponding
/// element of `b_mat` (broadcast to `mat`'s shape by `zip_mut_with`).
///
/// Takes ownership of `mat`, updates it in place and hands it back.
pub fn pow_1div2b<E>(mut mat: Array<f64, E>, b_mat: &Array2<f64>) -> Array<f64, E>
where
    E: ndarray::Dimension,
{
    mat.zip_mut_with(b_mat, |elem, &b| {
        let exponent = 1.0 / (2.0 * b);
        *elem = elem.powf(exponent);
    });
    mat
}

/// Element-wise power: returns a new array with every element of `a` raised to `d`.
fn pow<D>(a: &Array<f64, D>, d: f64) -> Array<f64, D>
where
    D: ndarray::Dimension,
{
    a.map(|&value| value.powf(d))
}

/// Returns, as a flat 1-D array, the elements of `a` whose matching entry in `mask` is `true`.
///
/// Both arrays are walked in logical (row-major) order, so the output order is stable and is
/// the same for every array filtered with the same mask.
///
/// # Panics
/// Panics if `a` and `mask` differ in shape. Zipping mismatched arrays would otherwise silently
/// drop elements or pair values with unrelated mask entries.
fn filter_by_mask<D>(a: &Array<f64, D>, mask: &Array<bool, D>) -> Array1<f64>
where
    D: ndarray::Dimension,
{
    assert_eq!(
        a.shape(),
        mask.shape(),
        "filter_by_mask: value and mask arrays must have the same shape"
    );
    a.iter()
        .zip(mask.iter())
        .filter_map(|(&value, &keep)| if keep { Some(value) } else { None })
        .collect()
}

#[cfg(test)]
mod tests {
    use crate::mp::{MPComponent, MPMethod};
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_ss_builder1() {
        let time = 0.0;
        let comp = MPComponent::new("x".to_string(), 0.0, 4.0, 4.0, 3.0, 1.0);
        let params = MPParams {dt: 5, epsilon: 0.01, method: MPMethod::BRUTEFORCE};

        // calculate and assign xis, xgs, b_mat, c_mat, td, tg for a given set of parameters
        let builder = ComponentSolutionSetBuilder::new(time, &comp, &params).calculate_parameters();

        let xi = builder.xis.unwrap();
        let correct_xi = arr2(&[
            [-2.7592,  -2.5120,  -2.2649,  -2.0177 , -1.7705],
            [-2.7592,  -2.5120,  -2.2649,  -2.0177 , -1.7705],
            [-2.7592,  -2.5120,  -2.2649,  -2.0177 , -1.7705],
            [-2.7592,  -2.5120,  -2.2649,  -2.0177 , -1.7705],
            [-2.7592,  -2.5120,  -2.2649,  -2.0177 , -1.7705]
            ]);

        assert!(
            xi.abs_diff_eq(&correct_xi, 1e-4),
            "Calculated Xi was:\n{}\nvs\n{}", xi, correct_xi
        );

        let b_mat = builder.b_mat.unwrap();
        let correct_bmat = arr2(&[
            [233.9734,     8.7648     ,4.2996     ,2.7716     ,2.0000],
            [233.9734,     8.7648     ,4.2996     ,2.7716     ,2.0000],
            [233.9734,     8.7648     ,4.2996     ,2.7716     ,2.0000],
            [233.9734,     8.7648     ,4.2996     ,2.7716     ,2.0000],
            [233.9734,     8.7648     ,4.2996     ,2.7716     ,2.0000],
            ]);
        assert!(
            b_mat.abs_diff_eq(&correct_bmat, 1e-4),
            "Calculated b_mat was:\n{}\nvs\n{}", b_mat, correct_bmat
        );

        let tg = builder.tg.unwrap();
        let correct = arr2(&[
            [3.7757,   4.3813,   5.2182,   6.4142,   8.1921],
            [3.7757,   4.3813,   5.2182,   6.4142,   8.1921],
            [3.7757,   4.3813,   5.2182,   6.4142,   8.1921],
            [3.7757,   4.3813,   5.2182,   6.4142,   8.1921],
            [3.7757,   4.3813,   5.2182,   6.4142,   8.1921],
            ]);

        assert!(
            tg.abs_diff_eq(&correct, 1e-4),
            "Calculated tg was:\n{}\nvs\n{}", tg, correct
        );

    }

    #[test]
    fn test_ss_builder2() {
        let time = 0.0;
        let comp = MPComponent::new("y".to_string(), 1.5, 2.8, 3.2, 3.0, 0.5);
        let params = MPParams {dt: 5, epsilon: 0.01, method:MPMethod::BRUTEFORCE};

        // calculate and assign xis, xgs, b_mat, c_mat, td, tg for a given set of parameters
        let builder = ComponentSolutionSetBuilder::new(time, &comp, &params).calculate_parameters();

        let xi = builder.xis.unwrap();
        let correct_xi = arr2(&[
            [ 0.297565   ,0.413825   ,0.530085   ,0.646345  , 0.762605],
            [ 0.211031   ,0.335574   ,0.460118   ,0.584661  , 0.709205],
            [ 0.125385   ,0.258101   ,0.390817   ,0.523534  , 0.656250],
            [ 0.040612   ,0.181393   ,0.322173   ,0.462954  , 0.603734],
            [-0.043299   ,0.105439   ,0.254177   ,0.402915  , 0.551653],
            ]);

        assert!(
            xi.abs_diff_eq(&correct_xi, 1e-4),
            "Calculated Xi was:\n{}\nvs\n{}", xi, correct_xi
        );

        let b_mat = builder.b_mat.unwrap();
        let correct_bmat = arr2(&[
            [116.6939,     8.8130,     4.3646,     2.7992,     2.0000],
            [124.7460,     8.8424,     4.3688,     2.7999,     2.0000],
            [132.6746,     8.8672,     4.3722,     2.8004,     2.0000],
            [140.4821,     8.8883,     4.3748,     2.8008,     2.0000],
            [148.1710,     8.9062,     4.3769,     2.8010,     2.0000],
            ]);
        assert!(
            b_mat.abs_diff_eq(&correct_bmat, 1e-4),
            "Calculated b_mat was:\n{}\nvs\n{}", b_mat, correct_bmat
        );

        let tg = builder.tg.unwrap();
        let correct = arr2(&[
            [1.1794,   1.3139,   1.4938,   1.7425,   2.1004],
            [1.2836,   1.4332,   1.6336,   1.9115,   2.3130],
            [1.3886,   1.5534,   1.7750,   2.0831,   2.5295],
            [1.4941,   1.6747,   1.9179,   2.2570,   2.7497],
            [1.6001,   1.7967,   2.0622,   2.4331,   2.9735],
            ]);

        assert!(
            tg.abs_diff_eq(&correct, 1e-4),
            "Calculated tg was:\n{}\nvs\n{}", tg, correct
        );
    }


    #[test]
    fn test_ss_object_avoidance() {
        let objaxis = objects::ObjectAxis{
            to_init: f64::NEG_INFINITY,
            to_end: f64::INFINITY,
            bounds: objects::Bounds::new(3.75, 4.25)
        };
        let ssve = SolutionSetVectorElem {
            xi: -2.512046027742749,
            xg: 4.0,
            bs: 8.76478816663791,
            cs: 9.258604177348941,
            td: 9.016132545423174,
            tg: 4.381251903980132,
        };
        let intersects = ssve.compute_object_intersections(&objaxis, 0.0);

        let correct_intersects = objects::Intersection {
            points: objects::IntersectionPoints {
                tkl: f64::NEG_INFINITY,
                tkl_u: f64::NEG_INFINITY,
                tkr_u:  2.1099103208467245,
                tkr: f64::INFINITY
            },
            flag: objects::IntersectionFlag::Partial
        };

        approx::assert_relative_eq!(intersects, correct_intersects);
    }
}