/* Copyright (C) 2020 Dylan Staatz - All Rights Reserved. */

pub mod polar;
pub mod spherical;

////////////////////////////////////////////////////////////////////////////////

use nalgebra::constraint::{SameNumberOfRows, ShapeConstraint};
use nalgebra::storage::Storage;
use nalgebra::{
  Const, Dim, SVector, Scalar, SimdBool, SimdRealField, SliceStorage, Vector,
};
use rand::distributions::{uniform::SampleUniform, Distribution};
use rand::{thread_rng, SeedableRng};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use crate::error::{InvalidParamError, Result};
use crate::params::FromParams;
use crate::rng::{LinearCoordinates, RNG};
use crate::trajectories::{FullTraj, LinearTrajectory};
use crate::util::{bounds::Bounds, norm::NormCost};

use super::CSpace;

/// A configuration space that holds two agents: a leader and a follower.
///
/// This is mostly a marker trait, but implementors are free to override the
/// provided methods and lay the state out differently.
///
/// The default implementation keeps the leader in the top `D` rows of the
/// `N`-dimensional state vector and the follower in the bottom `D` rows. It
/// requires `D * 2 == N`, which is only checked with `debug_assert_eq!`.
pub trait LeaderFollowerCSpace<X, const D: usize, const N: usize>:
  CSpace<X, N>
{
  /// Get a borrowed view of the leader's position (rows `0..D` of `state`).
  fn get_leader<S: Storage<X, Const<N>>>(
    state: &Vector<X, Const<N>, S>,
  ) -> Vector<
    X,
    Const<D>,
    SliceStorage<'_, X, Const<D>, Const<1_usize>, S::RStride, S::CStride>,
  > {
    debug_assert_eq!(D * 2, N);
    state.fixed_rows::<D>(0)
  }

  /// Get a borrowed view of the follower's position (rows `D..2 * D` of
  /// `state`).
  fn get_follower<S: Storage<X, Const<N>>>(
    state: &Vector<X, Const<N>, S>,
  ) -> Vector<
    X,
    Const<D>,
    SliceStorage<'_, X, Const<D>, Const<1_usize>, S::RStride, S::CStride>,
  > {
    debug_assert_eq!(D * 2, N);
    state.fixed_rows::<D>(D)
  }

  /// Build a full `N`-dimensional state by stacking `leader` on top of
  /// `follower` (the inverse of `get_leader` / `get_follower`).
  ///
  /// NOTE(review): nalgebra's `from_iterator` is expected to panic if the
  /// chained iterator yields fewer than `N` elements. With `2 * D < N` that
  /// would happen in release builds, where the `debug_assert` is compiled out.
  fn get_state<S1, S2>(
    leader: &Vector<X, Const<D>, S1>,
    follower: &Vector<X, Const<D>, S2>,
  ) -> SVector<X, N>
  where
    X: Scalar,
    S1: Storage<X, Const<D>>,
    S2: Storage<X, Const<D>>,
  {
    debug_assert_eq!(D * 2, N);
    SVector::<X, N>::from_iterator(
      leader.iter().chain(follower.iter()).cloned(),
    )
  }
}

/// Serializable parameters used to build a `LeaderFollowerSpace` through
/// `FromParams`.
///
/// The explicit `serde(bound(...))` replaces the bounds serde would otherwise
/// infer for the generic element type `X`.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[serde(bound(
  serialize = "X: Scalar + Serialize",
  deserialize = "X: Scalar + DeserializeOwned"
))]
pub struct LeaderFollowerSpaceParams<X, const N: usize> {
  /// Sampling bounds of the full `N`-dimensional state (leader and follower
  /// together); `LeaderFollowerSpace::new` rejects bounds that are not valid.
  pub bounds: Bounds<X, N>,
  /// Norm used for costs and straight-line trajectories.
  pub norm: NormCost,
  /// Optional RNG seed; when `None` the RNG is seeded from `thread_rng`.
  pub seed: Option<u64>,
  /// Initial `(lower, upper)` sensor range; `new` requires `lower < upper`.
  pub sensor_range: (X, X),
}

impl<X: Scalar + PartialEq, const N: usize> PartialEq
  for LeaderFollowerSpaceParams<X, N>
{
  /// Field-by-field equality.
  ///
  /// Destructuring `self` exhaustively means that adding a field to the
  /// params struct becomes a compile error here until it is compared too.
  fn eq(&self, other: &Self) -> bool {
    let Self {
      bounds,
      norm,
      seed,
      sensor_range,
    } = self;

    bounds == &other.bounds
      && norm == &other.norm
      && seed == &other.seed
      && sensor_range == &other.sensor_range
  }
}

/// A uniformly sampled cuboid holding a leader and a follower, with a normed
/// cost function.
///
/// The state is `N`-dimensional and `N` must equal `D * 2`. The leader
/// occupies the first `D` coordinates and the follower the last `D`. This is
/// only enforced by `debug_assert!`s.
pub struct LeaderFollowerSpace<X, const D: usize, const N: usize>
where
  X: SampleUniform,
{
  /// Volume of the sampling bounds, computed once in `new`.
  volume: X,
  /// Norm used for costs, trajectories and saturation.
  norm: NormCost,
  /// Random number generator used by `sample`.
  rng: RNG,
  /// Distribution over the bounds that `sample` draws from.
  distribution: LinearCoordinates<X, N>,
  /// Sensor range passed to `new`; `set_sensor_range` does not touch it.
  /// NOTE: "intial" is a misspelling that matches the public accessor name.
  intial_sensor_range: (X, X),
  /// Current sensor range, replaceable via `set_sensor_range`.
  sensor_range: (X, X),
}

impl<X, const D: usize, const N: usize> LeaderFollowerSpace<X, D, N>
where
  X: SimdRealField + SampleUniform + Copy,
{
  /// Creates a new leader/follower space.
  ///
  /// The scalar `bounds` and `sensor_range` are broadcast (`splat`) across
  /// every SIMD lane of `X`. The initial and the current sensor range both
  /// start out equal to `sensor_range`.
  ///
  /// # Errors
  ///
  /// Returns an `InvalidParamError` if the bounds are not valid, or if the
  /// lower end of the sensor range is not strictly below its upper end. The
  /// bounds are checked first.
  pub fn new(
    bounds: Bounds<X::Element, N>,
    norm: NormCost,
    rng: RNG,
    sensor_range: (X::Element, X::Element),
  ) -> Result<Self>
  where
    X::Element: Scalar,
  {
    debug_assert_eq!(D * 2, N);

    let (range_lo, range_hi) = sensor_range;
    let lane_bounds = Bounds::splat(bounds);
    let lane_range = (X::splat(range_lo), X::splat(range_hi));

    if !lane_bounds.is_valid() {
      return Err(
        InvalidParamError {
          parameter_name: "bounds",
          parameter_value: format!("{:?}", lane_bounds),
        }
        .into(),
      );
    }

    let volume = lane_bounds.volume();
    let distribution = lane_bounds.into();

    if !Self::range_is_ordered(lane_range) {
      return Err(
        InvalidParamError {
          parameter_name: "sensor_range",
          parameter_value: format!("{:?}", lane_range),
        }
        .into(),
      );
    }

    Ok(Self {
      volume,
      norm,
      rng,
      distribution,
      intial_sensor_range: lane_range,
      sensor_range: lane_range,
    })
  }

  /// Sensor range this space was constructed with (lane 0 of each end).
  ///
  /// The misspelled name is part of the public API and is kept as-is.
  pub fn intial_sensor_range(&self) -> (X::Element, X::Element) {
    Self::first_lanes(self.intial_sensor_range)
  }

  /// Current sensor range (lane 0 of each end).
  pub fn get_sensor_range(&self) -> (X::Element, X::Element) {
    Self::first_lanes(self.sensor_range)
  }

  /// Replaces the current sensor range.
  ///
  /// Returns `Some(())` on success, or `None` (leaving the current range
  /// untouched) when the lower end is not strictly below the upper end.
  pub fn set_sensor_range(
    &mut self,
    sensor_range: (X::Element, X::Element),
  ) -> Option<()> {
    let (range_lo, range_hi) = sensor_range;
    let candidate = (X::splat(range_lo), X::splat(range_hi));

    if !Self::range_is_ordered(candidate) {
      return None;
    }

    self.sensor_range = candidate;
    Some(())
  }

  /// True when every lane of `lo` is strictly less than the matching lane of
  /// `hi`.
  fn range_is_ordered((lo, hi): (X, X)) -> bool {
    lo.simd_lt(hi).all()
  }

  /// Pulls lane 0 out of both ends of a SIMD range.
  fn first_lanes((lo, hi): (X, X)) -> (X::Element, X::Element) {
    (lo.extract(0), hi.extract(0))
  }
}

// `LeaderFollowerSpace` uses the trait's default layout (leader in the top
// `D` rows, follower in the bottom `D` rows), so nothing is overridden.
impl<X: SimdRealField + SampleUniform + Copy, const D: usize, const N: usize>
  LeaderFollowerCSpace<X, D, N> for LeaderFollowerSpace<X, D, N>
where
  X::Element: Scalar,
{
}

impl<X, const D: usize, const N: usize> CSpace<X, N>
  for LeaderFollowerSpace<X, D, N>
where
  X: SimdRealField + SampleUniform + Copy,
  X::Element: Scalar,
{
  type Traj = LinearTrajectory<X, NormCost, N>;

  /// Volume of the sampling bounds, cached when the space was built.
  fn volume(&self) -> X {
    self.volume
  }

  /// Cost between `a` and `b` as measured by the configured `NormCost`.
  fn cost<R1, R2, S1, S2>(
    &self,
    a: &Vector<X, R1, S1>,
    b: &Vector<X, R2, S2>,
  ) -> X
  where
    R1: Dim,
    R2: Dim,
    S1: Storage<X, R1>,
    S2: Storage<X, R2>,
    ShapeConstraint: SameNumberOfRows<R1, R2>,
  {
    self.norm.cost(a, b)
  }

  /// Straight-line trajectory from `start` to `end`; always `Some`.
  fn trajectory<S1, S2>(
    &self,
    start: Vector<X, Const<N>, S1>,
    end: Vector<X, Const<N>, S2>,
  ) -> Option<FullTraj<X, Self::Traj, S1, S2, N>>
  where
    X: Scalar,
    S1: Storage<X, Const<N>>,
    S2: Storage<X, Const<N>>,
  {
    // Assumes that points are only sampled and saturated in this space
    Some(FullTraj::new(start, end, LinearTrajectory::new(self.norm)))
  }

  /// Every state is reported as free; this space has no obstacles.
  fn is_free<S>(&self, _: &Vector<X, Const<N>, S>) -> bool
  where
    S: Storage<X, Const<N>>,
  {
    // Assumes that points are only sampled and saturated in this space
    true
  }

  /// Pulls `a` toward `b`, splitting the step budget `delta` evenly between
  /// the leader (rows `0..D`) and the follower (rows `D..2 * D`).
  ///
  /// Each agent is moved along the segment from its position in `b` to its
  /// position in `a`. The scale `(delta / 2) / distance` is capped at one, so:
  /// - an agent farther than `delta / 2` is brought in to `delta / 2`
  ///   (assuming `NormCost` scales linearly with the displacement);
  /// - an agent already within `delta / 2` is left where it is, instead of
  ///   being pushed past its sampled position and possibly outside the
  ///   cuboid bounds that `is_free` relies on;
  /// - an agent whose distance is zero stays put, instead of becoming NaN
  ///   through `0 * (delta / 0)`.
  fn saturate(&self, a: &mut SVector<X, N>, b: &SVector<X, N>, delta: X) {
    // Each agent gets half of the total step budget.
    let delta = delta / (X::one() + X::one());

    // Saturate leader to be at most delta away
    let mut leader_a = a.fixed_rows_mut::<D>(0);
    let leader_b = b.fixed_rows::<D>(0);

    // Capping at one keeps the result between `leader_b` and the original
    // `leader_a`; a zero distance gives `inf`, which the cap turns into 1.
    let leader_scale =
      (delta / self.norm.cost(&leader_a, &leader_b)).simd_min(X::one());

    leader_a.set_column(0, &(&leader_a - &leader_b));
    leader_a.set_column(0, &(&leader_a * leader_scale));
    leader_a.set_column(0, &(&leader_a + &leader_b));

    // Saturate follower to be at most delta away
    let mut follower_a = a.fixed_rows_mut::<D>(D);
    let follower_b = b.fixed_rows::<D>(D);

    let follower_scale = (delta / self.norm.cost(&follower_a, &follower_b))
      .simd_min(X::one());

    follower_a.set_column(0, &(&follower_a - &follower_b));
    follower_a.set_column(0, &(&follower_a * follower_scale));
    follower_a.set_column(0, &(&follower_a + &follower_b));
  }

  /// Draws a fresh state from the sampling distribution.
  fn sample(&mut self) -> SVector<X, N> {
    self.distribution.sample(&mut self.rng)
  }
}

impl<X, const D: usize, const N: usize> FromParams
  for LeaderFollowerSpace<X, D, N>
where
  X: SimdRealField + SampleUniform + Copy,
  X::Element: Scalar,
{
  type Params = LeaderFollowerSpaceParams<X::Element, N>;
  fn from_params(params: Self::Params) -> Result<Self> {
    let rng = match params.seed {
      Some(seed) => RNG::seed_from_u64(seed),
      None => RNG::from_rng(thread_rng())?,
    };

    LeaderFollowerSpace::new(
      params.bounds,
      params.norm,
      rng,
      params.sensor_range,
    )
  }
}
