/*
MIT License

Copyright (c) 2021 University of Bristol Flight Laboratory

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Author: Mickey Li <mickeyhli@outlook.com>
*/

//! Module containing structures and algorithms for combining the found component solutions
use crate::{mp, objects};

/// Trait which all solvers must implement.
///
/// * `solve` guarantees that every solver exposes a single call which consumes
///   the solver and produces an `mp::MPSolution`.
/// * The `IntoIterator` supertrait guarantees that every solver can also be
///   turned into an iterator, allowing lazy generation of solutions.
pub trait MPSolver: IntoIterator {
    /// Consumes the solver and returns the resulting `mp::MPSolution`.
    fn solve(self) -> mp::MPSolution;
}

/// Brute force solver.
///
/// Owns the per-axis component solutions together with the number of objects
/// that each candidate combination is checked against. Convert it into an
/// iterator (see the `IntoIterator` implementation) to generate solutions
/// lazily, or call `solve` to gather them all at once.
#[derive(Clone)]
pub struct BruteForceSolver {
    /// Component solutions for the x, y and z components.
    sols: mp::MPComponentSolutions,
    /// Number of objects tested for intersection per candidate combination.
    num_objects: usize,
}

impl BruteForceSolver {
    pub fn new(sols: mp::MPComponentSolutions, num_objects: usize) -> Self {
        Self {sols, num_objects}
    }
}

impl MPSolver for BruteForceSolver {
    /// Drains the brute force iterator and packs every yielded element into a
    /// single `mp::MPSolution`.
    fn solve(self) -> mp::MPSolution {
        mp::MPSolution::from_vec(self.into_iter().collect())
    }
}

impl IntoIterator for BruteForceSolver {
    type Item = mp::MPSolutionElem;
    type IntoIter = BruteForceSolverIterator2;

    /// Consumes the solver and hands it to a new `BruteForceSolverIterator2`,
    /// which generates the solutions lazily.
    fn into_iter(self) -> Self::IntoIter {
        let solver = self;
        BruteForceSolverIterator2::new(solver)
    }
}

/// The brute force solver iterator object.
///
/// Uses an index-based method so that no references need to be stored
/// (to satisfy pyo3 requirements).
pub struct BruteForceSolverIterator2 {
    /// The solver whose component solutions are being combined.
    solver: BruteForceSolver,
    /// Current position as `[x, y, z]` indices into the component solutions.
    index: Vec<usize>,
    /// Number of solutions available in the x, y and z components respectively.
    max_index: (usize, usize, usize),
}
impl BruteForceSolverIterator2 {
    pub fn new(solver: BruteForceSolver) -> Self{
        let max_index = (
            solver.sols.xsol.solutions.len(),
            solver.sols.ysol.solutions.len(),
            solver.sols.zsol.solutions.len());
        Self {solver, index: vec![0, 0, 0], max_index}
    }
}
impl Iterator for BruteForceSolverIterator2 {
    type Item = mp::MPSolutionElem;

    /// Generates the next brute force solution.
    ///
    /// Walks the Cartesian product of the x, y and z component solutions
    /// (z varies fastest, x slowest) and returns the first combination for
    /// which no object passes `mp::utils::intersection_test_3`. Returns
    /// `None` once every combination has been visited, and keeps returning
    /// `None` on subsequent calls.
    ///
    /// The position is rolled over at the *start* of each step rather than
    /// being compared against `max - 1` up front, so the final combination is
    /// evaluated like every other one, and empty components cannot cause an
    /// arithmetic underflow.
    fn next(&mut self) -> Option<Self::Item> {
        let (max_x, max_y, max_z) = self.max_index;

        // If any component has no solutions the product is empty; bail out
        // before indexing into an empty collection.
        if max_x == 0 || max_y == 0 || max_z == 0 {
            return None;
        }

        loop {
            // Carry from z into y once z has run past its last solution.
            if self.index[2] >= max_z {
                self.index[2] = 0;
                self.index[1] += 1;
            }

            // Carry from y into x once y has run past its last solution.
            if self.index[1] >= max_y {
                self.index[1] = 0;
                self.index[0] += 1;
            }

            // x running past its last solution means the product is exhausted.
            if self.index[0] >= max_x {
                return None;
            }

            let x = &self.solver.sols.xsol.solutions[self.index[0]];
            let y = &self.solver.sols.ysol.solutions[self.index[1]];
            let z = &self.solver.sols.zsol.solutions[self.index[2]];

            // Always advance z; the carry logic above handles any overflow on
            // the next step (or the next call).
            self.index[2] += 1;

            // The combination is rejected as soon as one object has an
            // intersection recorded in all three components and the combined
            // intersection test succeeds for it.
            let collides = (0..self.solver.num_objects).any(|obs_id| {
                let xi = &x.intersects[obs_id];
                let yi = &y.intersects[obs_id];
                let zi = &z.intersects[obs_id];

                if xi.flag == objects::IntersectionFlag::None
                    || yi.flag == objects::IntersectionFlag::None
                    || zi.flag == objects::IntersectionFlag::None
                {
                    // At least one component has no intersection recorded for
                    // this object, so the full test is skipped.
                    return false;
                }

                mp::utils::intersection_test_3(
                    &xi.points.to_array(),
                    &yi.points.to_array(),
                    &zi.points.to_array(),
                )
            });

            if !collides {
                return Some(mp::MPSolutionElem::from_component_solution(x, y, z));
            }
        }
    }
}


/// Heuristic solver.
///
/// Owns the per-axis component solutions together with the number of objects
/// that candidate combinations are checked against. Convert it into an
/// iterator (see the `IntoIterator` implementation) to generate solutions
/// lazily, or call `solve` to gather them all at once.
#[derive(Clone)]
pub struct HeuristicSolver {
    /// Component solutions for the x, y and z components.
    sols: mp::MPComponentSolutions,
    /// Number of objects passed to the intersection test.
    num_objects: usize,
}

impl HeuristicSolver {
    pub fn new(sols: mp::MPComponentSolutions, num_objects: usize) -> Self {
        Self {sols, num_objects}
    }
}

impl MPSolver for HeuristicSolver {
    /// Drains the heuristic iterator and packs every yielded element into a
    /// single `mp::MPSolution`.
    fn solve(self) -> mp::MPSolution {
        mp::MPSolution::from_vec(self.into_iter().collect())
    }
}

impl IntoIterator for HeuristicSolver {
    type Item = mp::MPSolutionElem;
    type IntoIter = HeuristicIterator;

    /// Consumes the solver and hands it to a new `HeuristicIterator`, which
    /// generates the solutions lazily.
    fn into_iter(self) -> Self::IntoIter {
        let solver = self;
        HeuristicIterator::new(solver)
    }
}

/// The heuristic solver iterator object.
///
/// Uses an index-based method so that no references need to be stored
/// (to satisfy pyo3 requirements).
/// Slightly different from the matlab implementation.
pub struct HeuristicIterator {
    /// Number of objects passed to the intersection test.
    num_objects: usize,
    /// The three component solutions that are combined by the iterator.
    csols: Vec<mp::MPComponentSolution>,
    /// Current position: one index into each entry of `csols`.
    index: Vec<usize>,
    /// Upper bound (exclusive) for each entry of `index`.
    max_index: (usize, usize, usize),
}
impl HeuristicIterator {
    /// Creates an iterator positioned at the first combination of `solver`'s
    /// component solutions.
    ///
    /// The three component solutions are ordered by ascending number of
    /// solutions (the sort is stable, so ties keep their x, y, z order).
    /// `max_index` is taken from that *sorted* order so that each bound
    /// matches the collection it is used to index; taking the bounds in the
    /// unsorted x, y, z order can index past the end of a shorter component or
    /// skip solutions of a longer one.
    ///
    /// NOTE(review): after sorting, `csols[0..3]` are no longer guaranteed to
    /// be the x, y and z components in that order. Confirm that the consumers
    /// of these elements do not depend on axis order.
    pub fn new(solver: HeuristicSolver) -> Self {
        let mut csols = vec![solver.sols.xsol, solver.sols.ysol, solver.sols.zsol];
        csols.sort_by_key(|component| component.solutions.len());

        // Bounds must follow the sorted order, because `index[i]` is used to
        // index `csols[i]`.
        let max_index = (
            csols[0].solutions.len(),
            csols[1].solutions.len(),
            csols[2].solutions.len(),
        );

        Self {
            num_objects: solver.num_objects,
            csols,
            index: vec![0, 0, 0],
            max_index,
        }
    }
}
impl Iterator for HeuristicIterator {
    type Item = mp::MPSolutionElem;

    /// Generates the next heuristic solution.
    ///
    /// Walks the combinations of the three stored component solutions (the
    /// last index varies fastest) and returns the first combination accepted
    /// by the heuristic. The heuristic multiplies the per-object flag arrays
    /// of the components element-wise and only falls back to the full
    /// `mp::utils::intersection_test_3_component_solution_elems` test when the
    /// products alone do not decide the outcome.
    ///
    /// Returns `None` once every combination has been visited, or immediately
    /// if any component has no solutions (the product is then empty, and
    /// indexing into the empty component would otherwise panic).
    fn next(&mut self) -> Option<Self::Item> {
        // Empty product: nothing to yield, and indexing below would be out of
        // bounds for the empty component.
        if self.max_index.0 == 0 || self.max_index.1 == 0 || self.max_index.2 == 0 {
            return None;
        }

        loop {

            // Perform rollover of previous variable if hits max_index
            // If z hits the limit, increase y by one
            if self.index[2] >= self.max_index.2 {
                self.index[2] = 0;
                self.index[1] += 1;
            }

            // If y hits the limit, set y and z to start and increase x by one
            if self.index[1] >= self.max_index.1 {
                self.index[2] = 0;
                self.index[1] = 0;
                self.index[0] += 1;
            }

            // If x hits the limit, no more values remain, terminate iterator
            if self.index[0] >= self.max_index.0 {
                return None
            }

            // Get Components
            let x = &self.csols[0].solutions[self.index[0]];
            let y = &self.csols[1].solutions[self.index[1]];
            let z = &self.csols[2].solutions[self.index[2]];

            let mut hit = false;

            // Begin heuristic
            // Element-wise product of the flag arrays of the first two components.
            // NOTE(review): the meaning of the individual flag values (0, 1, 2, >2)
            // is defined elsewhere; the case labels below follow the original code.
            let y_row = x.get_all_flags_as_array() * y.get_all_flags_as_array();

            if y_row.iter().all(|x|*x==0) { //case 3a
                // Every product is zero: accept without looking at the third component.
                hit = true;
            } else if y_row.iter().any(|x|*x==1 || *x==2) { //case 3b
                // Fold in the flags of the third component.
                let z_row = &y_row * z.get_all_flags_as_array();
                if z_row.iter().all(|z|*z==0) { // case 4a
                    hit = true;
                } else if z_row.iter().any(|x|*x > 2) && !z_row.iter().any(|x|*x==1 || *x==2) { // case 4c
                    // Products are inconclusive: run the full intersection test.
                    if !mp::utils::intersection_test_3_component_solution_elems(&x, &y, &z, self.num_objects) {
                        hit = true;
                    }
                }
            } else if y_row.iter().any(|x|*x > 2) && !y_row.iter().any(|x|*x==1 || *x==2) { // case 3c
                if !mp::utils::intersection_test_3_component_solution_elems(&x, &y, &z, self.num_objects){ // case 4b
                    hit = true;
                }
            }

            // Always update the z index (it will automatically roll over if necessary)
            self.index[2] += 1;
            if hit {
                return Some(mp::MPSolutionElem::from_component_solution(x, y, z));
            }
        }
    }
}