/* Copyright (C) 2020 Dylan Staatz - All Rights Reserved. */

use nalgebra::constraint::{SameNumberOfRows, ShapeConstraint};
use nalgebra::storage::Storage;
use nalgebra::{Const, Dim, SVector, Vector};
use rand::distributions::{uniform::SampleUniform, Distribution};
use rand::{thread_rng, SeedableRng};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use crate::cspace::CSpace;
use crate::error::{InvalidParamError, Result};
use crate::params::FromParams;
use crate::rng::{LinearCoordinates, RNG};
use crate::scalar::Scalar;
use crate::trajectories::{FullTraj, LinearTrajectory};
use crate::util::{bounds::Bounds, norm::NormCost};

/// Plain-data description of a [`LinearSpace`], consumed by its
/// [`FromParams`] implementation.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(bound(
  serialize = "X: Serialize",
  deserialize = "X: DeserializeOwned"
))]
pub struct LinearSpaceParams<X, const N: usize>
where
  X: Scalar,
{
  /// Bounds of the space; validated when the space is constructed.
  pub bounds: Bounds<X, N>,
  /// Norm used by the space to compute costs.
  pub norm: NormCost,
  /// Seed for the sampling RNG. With `None` the RNG is seeded from
  /// `thread_rng()` instead.
  pub seed: Option<u64>,
}

/// A cuboid configuration space that is sampled uniformly and whose cost
/// function is a norm.
pub struct LinearSpace<X: SampleUniform, const N: usize> {
  /// Volume reported by the bounds at construction time.
  volume: X,
  /// Norm used for costs and handed to every trajectory that is created.
  norm: NormCost,
  /// Random number generator that drives sampling.
  rng: RNG,
  /// Coordinate distribution converted from the bounds.
  distribution: LinearCoordinates<X, N>,
}

impl<X, const N: usize> LinearSpace<X, N>
where
  X: Scalar + SampleUniform,
{
  /// Creates a space covering `bounds` that measures cost with `norm` and
  /// draws its samples using `rng`.
  ///
  /// The volume is read from `bounds` once and cached; the sampling
  /// distribution is obtained by converting `bounds`.
  ///
  /// # Errors
  ///
  /// Returns an [`InvalidParamError`] (converted into the crate error type)
  /// for the parameter `"bounds"` when `bounds.is_valid()` is false.
  pub fn new(bounds: Bounds<X, N>, norm: NormCost, rng: RNG) -> Result<Self> {
    if bounds.is_valid() {
      let volume = bounds.volume();
      Ok(Self {
        volume,
        norm,
        rng,
        distribution: bounds.into(),
      })
    } else {
      Err(
        InvalidParamError {
          parameter_name: "bounds",
          parameter_value: format!("{:?}", bounds),
        }
        .into(),
      )
    }
  }
}

impl<X, const N: usize> CSpace<X, N> for LinearSpace<X, N>
where
  X: Scalar + SampleUniform,
{
  /// Trajectories are `LinearTrajectory` values parameterised by the norm.
  type Traj = LinearTrajectory<X, NormCost, N>;

  /// Returns the volume that was cached when the space was constructed.
  fn volume(&self) -> X {
    self.volume
  }

  /// Returns the cost between `a` and `b` as computed by the configured norm.
  fn cost<R1, R2, S1, S2>(
    &self,
    a: &Vector<X, R1, S1>,
    b: &Vector<X, R2, S2>,
  ) -> X
  where
    X: Scalar,
    R1: Dim,
    R2: Dim,
    S1: Storage<X, R1>,
    S2: Storage<X, R2>,
    ShapeConstraint: SameNumberOfRows<R1, R2>
      + SameNumberOfRows<R1, Const<N>>
      + SameNumberOfRows<R2, Const<N>>,
  {
    self.norm.cost(a, b)
  }

  /// Connects `start` to `end` with a linear trajectory.
  ///
  /// Always returns `Some`: no feasibility check is performed here.
  fn trajectory<S1, S2>(
    &self,
    start: Vector<X, Const<N>, S1>,
    end: Vector<X, Const<N>, S2>,
  ) -> Option<FullTraj<X, Self::Traj, S1, S2, N>>
  where
    X: Scalar,
    S1: Storage<X, Const<N>>,
    S2: Storage<X, Const<N>>,
  {
    // Relies on callers only passing points that were sampled from, or
    // saturated within, this space.
    let segment = LinearTrajectory::new(self.norm);
    let full = FullTraj::new(start, end, segment);
    Some(full)
  }

  /// Reports every point as free; the argument is not inspected.
  fn is_free<S>(&self, _: &Vector<X, Const<N>, S>) -> bool
  where
    S: Storage<X, Const<N>>,
  {
    // Relies on callers only passing points that were sampled from, or
    // saturated within, this space.
    true
  }

  /// Rescales the offset of `a` from `b` in place by `delta / cost(a, b)`,
  /// leaving `a` on the line through `b` and its previous position.
  ///
  /// NOTE(review): if the cost between `a` and `b` is zero (presumably when
  /// they coincide) the scale factor is not finite -- confirm that callers
  /// never pass identical points.
  fn saturate(&self, a: &mut SVector<X, N>, b: &SVector<X, N>, delta: X) {
    let distance = self.norm.cost(a, b);
    let scale = delta / distance;

    // Shift so `b` is the origin, scale the offset, then shift back.
    *a -= *b;
    *a *= scale;
    *a += *b;
  }

  /// Draws the next point from the coordinate distribution, advancing the
  /// internal RNG.
  fn sample(&mut self) -> SVector<X, N> {
    let rng = &mut self.rng;
    self.distribution.sample(rng)
  }
}

impl<X, const N: usize> FromParams for LinearSpace<X, N>
where
  X: Scalar + SampleUniform,
{
  type Params = LinearSpaceParams<X, N>;
  fn from_params(params: Self::Params) -> Result<Self> {
    let rng = match params.seed {
      Some(seed) => RNG::seed_from_u64(seed),
      None => RNG::from_rng(thread_rng())?,
    };
    LinearSpace::new(params.bounds, params.norm, rng)
  }
}

#[cfg(test)]
mod tests {

  use rand::SeedableRng;
  use rayon::iter::{IntoParallelIterator, ParallelIterator};

  use super::*;

  /// Fixed seed used to construct the RNG handed to the space.
  const SEED: u64 = 0xe580e2e93fd6b040;

  /// Calls `is_free` from many rayon tasks through a shared reference to
  /// check that the space can be read in parallel.
  #[test]
  fn test_parallel_sample() {
    const QUERIES: usize = 1000;

    let lower: [f32; 2] = [-2.0, -2.0];
    let upper = [2.0, 2.0];
    let bounds = Bounds::new(lower.into(), upper.into());
    let space =
      LinearSpace::new(bounds, NormCost::TwoNorm, RNG::seed_from_u64(SEED))
        .unwrap();

    let results: Vec<bool> = (0..QUERIES)
      .into_par_iter()
      .map(|_| {
        let point: SVector<f32, 2> = [-1.0, -1.0].into();
        space.is_free(&point) // parallel read access to config space
      })
      .collect();
    assert_eq!(results.len(), QUERIES);

    // The same query must also succeed on the calling thread.
    let point: SVector<f32, 2> = [-1.0, -1.0].into();
    assert!(space.is_free(&point));
  }
}
