/* Copyright (C) 2020 Dylan Staatz - All Rights Reserved. */

use nalgebra::constraint::{SameNumberOfRows, ShapeConstraint};
use nalgebra::storage::Storage;
use nalgebra::{Const, Dim, RealField, SVector, Vector};
use num_traits::Float;
use rand::distributions::{uniform::SampleUniform, Distribution};
use rand::{thread_rng, SeedableRng};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use crate::error::{InvalidParamError, Result};
use crate::params::FromParams;
use crate::rng::{LinearCoordinates, RNG};
use crate::scalar::Scalar;
use crate::trajectories::{EuclideanTrajectory, FullTraj};
use crate::util::bounds::Bounds;
use crate::util::math::{atan2, unit_d_ball_vol};

use super::super::CSpace;
use super::LeaderFollowerCSpace;

/// Spatial dimension of each robot's position (the plane).
pub const D: usize = 2;
/// Dimension of the full state: one `D`-vector for the leader followed by one
/// `D`-vector for the follower.
pub const N: usize = 2 * D;

/// Serializable construction parameters for [`LeaderFollowerPolarSpace`],
/// consumed by its [`FromParams`] implementation.
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(bound(serialize = "X: Serialize", deserialize = "X: DeserializeOwned"))]
pub struct LeaderFollowerPolarSpaceParams<X: Scalar> {
  /// Workspace bounds (per-axis `mins` / `maxs`).
  pub bounds: Bounds<X, D>,
  /// RNG seed; when `None` the RNG is seeded from `thread_rng()`.
  pub seed: Option<u64>,
  /// `(min, max)` allowed distance between leader and follower.
  pub sensor_range: (X, X),
}

/// A planar leader-follower configuration space whose follower is sampled in
/// polar coordinates around the leader.
///
/// The states produced by `sample` and consumed by `is_free` and `saturate`
/// (see the [`CSpace`] implementation) are absolute positions
/// `[x1, y1, x2, y2]`:
///
/// - `x1`, `y1`: position of the leader robot
/// - `x2`, `y2`: position of the follower robot
///
/// Sampling and saturation work internally in the relative form
/// `[x1, y1, r2, theta2]`, where `r2` is the follower's distance from the
/// leader and `theta2` is the follower's angle about the leader, measured from
/// the x-axis; see [`rel_to_abs`] and [`abs_to_rel`].
pub struct LeaderFollowerPolarSpace<X>
where
  X: Scalar + SampleUniform,
{
  /// Workspace bounds: leader coordinates are sampled inside them, and
  /// `is_free` requires the follower to lie inside them.
  bounds: Bounds<X, D>,
  /// Value reported by `CSpace::volume`.
  volume: X,
  /// Random number generator used by `sample`.
  rng: RNG,
  /// Sampler over `[x1, y1, u, theta2]` with limits set in `new`
  /// (`u` in `[0, 1]` is later mapped onto the sensor-annulus radius).
  distribution: LinearCoordinates<X, N>,
  /// Sensor range passed to `new`. The field name's spelling matches the
  /// existing `intial_sensor_range` accessor.
  intial_sensor_range: (X, X),
  /// Current `(min, max)` allowed leader-follower distance.
  sensor_range: (X, X),
  /// `sensor_range` with both ends squared, cached for sampling.
  sensor_range_squared: (X, X),
}

impl<X> LeaderFollowerPolarSpace<X>
where
  X: Scalar + SampleUniform,
{
  /// Creates a new leader-follower polar space.
  ///
  /// * `bounds` - workspace bounds; leader coordinates are sampled inside them.
  /// * `rng` - random number generator used by `sample`.
  /// * `sensor_range` - `(min, max)` allowed leader-follower distance.
  ///
  /// # Errors
  ///
  /// Returns an [`InvalidParamError`] (converted into the crate error type) if
  /// `bounds` is not valid, or if `sensor_range` does not satisfy
  /// `0 <= min < max` (NaN is rejected).
  pub fn new(
    bounds: Bounds<X, D>,
    rng: RNG,
    sensor_range: (X, X),
  ) -> Result<Self> {
    debug_assert_eq!(D * 2, N);

    // Validate bounds
    if !bounds.is_valid() {
      Err(InvalidParamError {
        parameter_name: "bounds",
        parameter_value: format!("{:?}", bounds),
      })?;
    }

    // Validate sensor range. A negative minimum must be rejected: sampling
    // uses the squared range, which would then disagree with the unsquared
    // range checked by `is_free`.
    if !Self::is_sensor_range_valid(sensor_range) {
      Err(InvalidParamError {
        parameter_name: "sensor_range",
        parameter_value: format!("{:?}", sensor_range),
      })?;
    }

    let sensor_range_squared = Self::squared_range(sensor_range);
    let volume = Self::space_volume(&bounds, sensor_range_squared);

    // Sampling box for [x1, y1, u, theta2]. `u` in [0, 1] is mapped onto the
    // sensor-annulus radius in `sample`.
    let mins = SVector::<X, N>::from([
      bounds.mins[0],
      bounds.mins[1],
      X::zero(),
      X::zero(),
    ]);

    let maxs = SVector::<X, N>::from([
      bounds.maxs[0],
      bounds.maxs[1],
      X::one(),
      X::two_pi(),
    ]);

    let distribution = LinearCoordinates::new(mins, maxs);

    Ok(Self {
      bounds,
      volume,
      rng,
      distribution,
      intial_sensor_range: sensor_range,
      sensor_range,
      sensor_range_squared,
    })
  }

  /// Returns the sensor range the space was constructed with.
  pub fn intial_sensor_range(&self) -> (X, X) {
    self.intial_sensor_range
  }

  /// Returns the current `(min, max)` sensor range.
  pub fn get_sensor_range(&self) -> (X, X) {
    self.sensor_range
  }

  /// Replaces the sensor range and recomputes the cached squared range and
  /// the space volume (which depends on the annulus size).
  ///
  /// Returns `None` and leaves the space unchanged when `sensor_range` does
  /// not satisfy `0 <= min < max`.
  pub fn set_sensor_range(&mut self, sensor_range: (X, X)) -> Option<()> {
    if !Self::is_sensor_range_valid(sensor_range) {
      // Invalid value
      return None;
    }

    let sensor_range_squared = Self::squared_range(sensor_range);
    self.sensor_range = sensor_range;
    self.sensor_range_squared = sensor_range_squared;
    // Keep `volume` consistent with the new annulus; it would otherwise
    // still describe the previous range.
    self.volume = Self::space_volume(&self.bounds, sensor_range_squared);
    Some(())
  }

  /// `true` when `range` is `(min, max)` with `0 <= min < max`.
  /// Comparisons with NaN are false, so NaN is rejected.
  fn is_sensor_range_valid(range: (X, X)) -> bool {
    X::zero() <= range.0 && range.0 < range.1
  }

  /// Squares both ends of a `(min, max)` range.
  fn squared_range(range: (X, X)) -> (X, X) {
    (range.0 * range.0, range.1 * range.1)
  }

  /// Space volume: `bounds.volume()` times
  /// `(max^2 - min^2) * unit_d_ball_vol(D)`. The second factor is the sensor
  /// annulus area, assuming `unit_d_ball_vol` returns the unit-ball volume.
  fn space_volume(bounds: &Bounds<X, D>, sensor_range_squared: (X, X)) -> X {
    let sensor_space_volume = (sensor_range_squared.1 * unit_d_ball_vol(D))
      - (sensor_range_squared.0 * unit_d_ball_vol(D));
    bounds.volume() * sensor_space_volume
  }
}

/// Relies entirely on the default methods of [`LeaderFollowerCSpace`].
impl<X: Scalar + SampleUniform> LeaderFollowerCSpace<X, D, N>
  for LeaderFollowerPolarSpace<X>
{
}

impl<X> CSpace<X, N> for LeaderFollowerPolarSpace<X>
where
  X: Scalar + SampleUniform,
{
  type Traj = EuclideanTrajectory<X, N>;

  /// Returns the cached `volume` field.
  fn volume(&self) -> X {
    self.volume
  }

  /// Euclidean distance between the two state vectors.
  fn cost<R1, R2, S1, S2>(
    &self,
    a: &Vector<X, R1, S1>,
    b: &Vector<X, R2, S2>,
  ) -> X
  where
    X: Scalar,
    R1: Dim,
    R2: Dim,
    S1: Storage<X, R1>,
    S2: Storage<X, R2>,
    ShapeConstraint: SameNumberOfRows<R1, R2>
      + SameNumberOfRows<R1, Const<N>>
      + SameNumberOfRows<R2, Const<N>>,
  {
    a.metric_distance(b)
  }

  /// Straight-line (Euclidean) trajectory from `start` to `end`; never `None`.
  fn trajectory<S1, S2>(
    &self,
    start: Vector<X, Const<N>, S1>,
    end: Vector<X, Const<N>, S2>,
  ) -> Option<FullTraj<X, Self::Traj, S1, S2, N>>
  where
    X: Scalar,
    S1: Storage<X, Const<N>>,
    S2: Storage<X, Const<N>>,
  {
    Some(FullTraj::new(start, end, EuclideanTrajectory::new()))
  }

  /// An absolute state `[x1, y1, x2, y2]` is free when the follower lies
  /// inside `bounds` and the leader-follower distance is within
  /// `sensor_range` (both ends inclusive).
  ///
  /// NOTE(review): the leader's position is not checked against `bounds`
  /// here.
  fn is_free<S>(&self, a: &Vector<X, Const<N>, S>) -> bool
  where
    S: Storage<X, Const<N>>,
  {
    let leader_abs = a.fixed_rows::<D>(0);
    let follower_abs = a.fixed_rows::<D>(D);

    if !self.bounds.within(&follower_abs) {
      return false;
    }

    // Determine distance between leader and follower
    let r2 = leader_abs.metric_distance(&follower_abs);
    self.sensor_range.0 <= r2 && r2 <= self.sensor_range.1
  }

  /// Saturates the absolute state `a` towards `b`, in place.
  ///
  /// `delta` is split evenly between the two robots:
  ///
  /// 1. Leader: if `a`'s leader is more than `delta / 2` from `b`'s leader,
  ///    it is moved onto the segment between them, exactly `delta / 2` from
  ///    `b`'s leader. Otherwise it is left unchanged.
  /// 2. Follower: both followers are expressed in polar coordinates relative
  ///    to the (possibly moved) leader of `a`. `b`'s follower radius is
  ///    clamped into `sensor_range`, and the radial distance spent doing so is
  ///    deducted from the follower's `delta / 2` budget. Any remaining budget
  ///    is used to step from `b`'s follower towards `a`'s follower along a
  ///    linear spiral (see [`saturate_polar_zero`]), and the resulting radius
  ///    is clamped into `sensor_range`. With no budget left, `b`'s clamped
  ///    follower is used.
  ///
  /// Finally `a` is converted back to absolute `[x1, y1, x2, y2]` form.
  fn saturate(&self, a: &mut SVector<X, N>, b: &SVector<X, N>, delta: X) {
    let (r_min, r_max) = self.sensor_range;
    let half_delta = delta / (X::one() + X::one());

    // --- Leader ---------------------------------------------------------
    let b_lead = b.fixed_rows::<D>(0).into_owned();
    let a_lead_start = a.fixed_rows::<D>(0).into_owned();
    let lead_dist = a_lead_start.metric_distance(&b_lead);
    // Only pull the leader in when it is farther away than its budget. This
    // never pushes it past `a`, and it avoids dividing by zero when the
    // leaders coincide (which used to turn the leader into NaN).
    if half_delta < lead_dist {
      let lead_scale = half_delta / lead_dist;
      let new_lead = (&a_lead_start - &b_lead) * lead_scale + &b_lead;
      let mut a_lead_mut = a.fixed_rows_mut::<D>(0);
      a_lead_mut.set_column(0, &new_lead);
    }
    let a_lead = a.fixed_rows::<D>(0).into_owned();

    // --- Follower -------------------------------------------------------
    // a's follower relative to a's (new) leader. `a` holds absolute
    // coordinates here, so the leader must be subtracted before converting to
    // polar form. `rel_to_abs` below measures the offset from this same
    // leader.
    let a_fol_abs = a.fixed_rows::<D>(D);
    let a_fol_offset = &a_fol_abs - &a_lead;
    let a_fol_rel = cartesian_to_polar(&a_fol_offset);

    // b's follower relative to a's (new) leader.
    let mut b_shifted = b.clone_owned();
    {
      let mut b_lead_mut = b_shifted.fixed_rows_mut::<D>(0);
      b_lead_mut.set_column(0, &a_lead);
    }
    let b_rel = abs_to_rel(&b_shifted);
    let mut b_fol_rel = b_rel.fixed_rows::<D>(D).into_owned();

    // Clamp b's follower radius into the sensor range, charging the radial
    // correction against the follower's budget.
    let mut fol_delta = half_delta;
    if r_max < b_fol_rel[0] {
      fol_delta -= b_fol_rel[0] - r_max;
      b_fol_rel[0] = r_max;
    }
    if b_fol_rel[0] < r_min {
      fol_delta -= r_min - b_fol_rel[0];
      b_fol_rel[0] = r_min;
    }

    let fol_rel = if X::zero() < fol_delta {
      let mut star_fol_rel = saturate_polar(&a_fol_rel, &b_fol_rel, fol_delta);

      log::debug!(
        "sensor_range: {:?}, star_fol_rel: {:?}",
        self.sensor_range,
        <[X; D]>::from(star_fol_rel)
      );

      // bound rho
      if r_max < star_fol_rel[0] {
        log::warn!("rho too big: [{}, {}]", star_fol_rel[0], star_fol_rel[1]);
        star_fol_rel[0] = r_max;
      }
      if star_fol_rel[0] < r_min {
        log::warn!("rho too small: [{}, {}]", star_fol_rel[0], star_fol_rel[1]);
        star_fol_rel[0] = r_min;
      }
      star_fol_rel
    } else {
      b_fol_rel
    };

    // Store the follower in relative polar form, then convert the whole state
    // back to absolute coordinates.
    let mut a_fol_mut = a.fixed_rows_mut::<D>(D);
    a_fol_mut.set_column(0, &fol_rel);
    *a = rel_to_abs(&*a);
  }

  /// Draws `[x1, y1, u, theta2]` from `distribution` and maps `u` to the
  /// follower radius `sqrt(u * (max^2 - min^2) + min^2)`. The state is
  /// returned in absolute `[x1, y1, x2, y2]` form.
  ///
  /// If `u` is uniform on `[0, 1]` (presumably the case for
  /// `LinearCoordinates`; confirm), the square-root mapping makes the
  /// follower uniform over the annulus area.
  fn sample(&mut self) -> SVector<X, N> {
    let mut sample = self.distribution.sample(&mut self.rng);

    let a_2 = self.sensor_range_squared.0;
    let b_2 = self.sensor_range_squared.1;
    sample[2] = <X as Float>::sqrt(sample[2] * (b_2 - a_2) + a_2);

    rel_to_abs(&sample)
  }
}

impl<X> FromParams for LeaderFollowerPolarSpace<X>
where
  X: Scalar + SampleUniform,
{
  type Params = LeaderFollowerPolarSpaceParams<X>;
  fn from_params(params: Self::Params) -> Result<Self> {
    let rng = match params.seed {
      Some(seed) => RNG::seed_from_u64(seed),
      None => RNG::from_rng(thread_rng())?,
    };

    LeaderFollowerPolarSpace::new(params.bounds, rng, params.sensor_range)
  }
}

/// Converts [x1, y1, r2, theta2] to [x1, y1, x2, y2]
///
/// Inverse of [`abs_to_rel`]
pub fn rel_to_abs<X, R, S>(v: &Vector<X, R, S>) -> SVector<X, N>
where
  X: RealField + Copy,
  R: Dim,
  S: Storage<X, R>,
  ShapeConstraint: SameNumberOfRows<R, Const<N>>,
{
  let x1 = v[0];
  let y1 = v[1];
  let x2 = v[0] + (v[2] * v[3].cos());
  let y2 = v[1] + (v[2] * v[3].sin());

  [x1, y1, x2, y2].into()
}

/// Converts [r, theta] to [x, y]
pub fn polar_to_cartesian<X, R, S>(v: &Vector<X, R, S>) -> SVector<X, D>
where
  X: RealField + Copy,
  R: Dim,
  S: Storage<X, R>,
  ShapeConstraint: SameNumberOfRows<R, Const<D>>,
{
  let x = v[0] * v[1].cos();
  let y = v[0] * v[1].sin();
  [x, y].into()
}

/// Converts [x1, y1, x2, y2] to [x1, y1, r2, theta2]
///
/// Inverse of [`rel_to_abs`]
pub fn abs_to_rel<X, R, S>(v: &Vector<X, R, S>) -> SVector<X, N>
where
  X: RealField + Copy,
  R: Dim,
  S: Storage<X, R>,
  ShapeConstraint: SameNumberOfRows<R, Const<N>>,
{
  let x1 = v[0];
  let y1 = v[1];
  let x2 = v[2] - x1; // offset cartesian
  let y2 = v[3] - y1; // offset cartesian

  let r2 = (x2 * x2 + y2 * y2).sqrt();
  let theta2 = atan2(y2, x2).unwrap_or(X::zero());

  [x1, y1, r2, bound_theta(theta2)].into()
}

/// Converts [x, y] to [r, theta]
pub fn cartesian_to_polar<X, R, S>(v: &Vector<X, R, S>) -> SVector<X, D>
where
  X: RealField + Copy,
  R: Dim,
  S: Storage<X, R>,
  ShapeConstraint: SameNumberOfRows<R, Const<D>>,
{
  let r = (v[0] * v[0] + v[1] * v[1]).sqrt();
  let theta = atan2(v[1], v[0]).unwrap_or(X::zero());
  [r, bound_theta(theta)].into()
}

fn saturate_polar<X, R1, S1, R2, S2>(
  a: &Vector<X, R1, S1>,
  b: &Vector<X, R2, S2>,
  delta: X,
) -> SVector<X, D>
where
  X: RealField + Copy,
  R1: Dim,
  R2: Dim,
  S1: Storage<X, R1>,
  S2: Storage<X, R2>,
  ShapeConstraint: SameNumberOfRows<R1, R2>
    + SameNumberOfRows<R1, Const<D>>
    + SameNumberOfRows<R2, Const<D>>,
{
  let r_a = a[0];
  let theta_a = a[1];
  let r_b = b[0];
  let theta_b = b[1];

  let (r_star, theta_star) =
    saturate_polar_zero(r_a, bound_theta(theta_a - theta_b), r_b, delta);
  [r_star, bound_theta(theta_star + theta_b)].into()
}

/// Saturates polar point `a = (r_a, theta_a)` towards `b = (r_b, 0)`.
/// `theta_b` is assumed to be 0.
///
/// The path from `b` to `a` is the Archimedean spiral
/// `r(theta) = m * theta + r_b`, whose radius varies linearly with angle. It
/// rotates the short way around: clockwise when `theta_a > pi`. The returned
/// `(r_star, theta_star)` lies on that path at arc length `delta` from `b`.
/// If the whole path is shorter than `delta`, `a` itself is returned.
/// `theta_star` is wrapped into `[0, 2*pi)`.
///
/// Special cases:
/// - `theta_a == 0`: the path is the radial segment between the radii.
/// - (nearly) equal radii: the path is a circular arc, so
///   `theta_star = delta / r_b` before clamping. The general integral below
///   divides by `m` and would otherwise produce NaN or cancel catastrophically.
///
/// In the general case `theta_star` is found with two Newton iterations, so it
/// is an approximation.
///
/// # Panics
///
/// Panics unless `delta > 0`, `r_a >= 0`, `r_b >= 0` and
/// `0 <= theta_a < 2*pi`.
pub fn saturate_polar_zero<X>(r_a: X, theta_a: X, r_b: X, delta: X) -> (X, X)
where
  X: RealField + Copy,
{
  let two = X::one() + X::one();

  log::debug!("delta: {:?}", delta);

  // Assert the correct bounds. Zero radii are allowed (e.g. a zero minimum
  // sensor range): the spiral stays well-defined with an endpoint on the pole.
  assert!(X::zero() < delta);
  assert!(X::zero() <= r_a);
  assert!(X::zero() <= r_b);
  assert!(X::zero() <= theta_a && theta_a < X::two_pi());

  // Same angle: the path is the radial segment from b to a. Step `delta`
  // away from b towards a, or return a when it is already within `delta`.
  if theta_a == X::zero() {
    let dr = r_a - r_b;
    let r_star = if delta < dr {
      r_b + delta
    } else if dr < -delta {
      r_b - delta
    } else {
      r_a
    };
    return (r_star, X::zero());
  }

  // Determine which direction we should rotate from theta_b => theta_star => theta_a
  let clockwise = X::pi() < theta_a;
  log::debug!("clockwise: {:?}", clockwise);

  // Rotating clockwise means negative angles; the arc-length target is
  // negated to match, because arc length measured from 0 is then negative.
  let (theta_a, delta) = if clockwise {
    (theta_a - X::two_pi(), -delta)
  } else {
    (theta_a, delta)
  };

  // theta_a is now in (-pi, pi] (and non-zero), still representing the same angle

  // Spiral r(theta) = m * theta + r_0 passing through b (theta = 0) and a.
  let m = (r_a - r_b) / theta_a;
  let m_2 = m.powi(2);
  let r_0 = r_b; // should be -m * theta_b + r_b, but theta_b is assumed to be 0
  let r = |theta: X| m * theta + r_0;

  let circular_threshold = X::default_epsilon().sqrt() * r_a.max(r_b);
  let mut theta_star = if m.abs() <= circular_threshold {
    // (Nearly) constant radius: arc length is r_b * theta.
    delta / r_b
  } else {
    // Indefinite integral of the spiral arc-length element sqrt(r^2 + m^2)
    let integral = |theta: X| {
      let r_theta = r(theta);
      let temp = (r_theta.powi(2) + m_2).sqrt();

      let term1 = (r_theta * temp) / (two * m);
      let term2 = (m / two) * (temp + r_theta).ln();
      term1 + term2
    };

    // f(theta) = arc length from 0 to theta, minus delta
    let constant = -integral(X::zero()) - delta;
    let f = |theta: X| integral(theta) + constant;

    // The derivative of f (i.e. the spiral arc-length element)
    let fd = |theta: X| (r(theta).powi(2) + m_2).sqrt();

    // Newton's method, starting from the circular-arc estimate with the mean radius
    let mut theta = delta * two / (r_a + r_b);
    let iterations = 2;
    for _ in 0..iterations {
      theta -= f(theta) / fd(theta);
    }
    theta
  };

  // Bound theta_star to be between theta_a and theta_b (0.0)
  if clockwise {
    assert!(theta_a < X::zero());
    theta_star = theta_star.max(theta_a);
    theta_star = theta_star.min(X::zero());
  } else {
    assert!(X::zero() < theta_a);
    theta_star = theta_star.max(X::zero());
    theta_star = theta_star.min(theta_a);
  }

  let r_star = r(theta_star);

  // Return and bound theta_star
  (r_star, bound_theta(theta_star))
}

/// Wraps an angle into the half-open interval `[0, 2*pi)`.
///
/// Uses the floating-point remainder instead of repeatedly adding or
/// subtracting `2*pi`. The loop form never terminated for infinite inputs, nor
/// for finite inputs so large that `2*pi` is below half an ulp. This form
/// runs in constant time. Non-finite inputs yield NaN.
fn bound_theta<X>(theta: X) -> X
where
  X: RealField + Copy,
{
  let two_pi = X::two_pi();

  // `%` keeps the sign of the dividend, so the remainder is in (-2*pi, 2*pi).
  let wrapped = theta % two_pi;
  let wrapped = if wrapped < X::zero() {
    wrapped + two_pi
  } else {
    wrapped
  };

  // Adding 2*pi to a tiny negative remainder can round up to exactly 2*pi;
  // fold that back to 0 so the upper bound stays exclusive.
  if two_pi <= wrapped {
    X::zero()
  } else {
    wrapped
  }
}

#[cfg(test)]
mod tests {

  use super::*;

  /// `true` when `a` and `b` differ by less than `10^-3`. The absolute
  /// difference is printed to ease debugging failures.
  fn rel_eq<X: RealField + Copy>(a: X, b: X) -> bool {
    let diff = (a - b).abs();
    println!("{:?}", diff);
    let tolerance = X::from_subset(&10.0).powi(-3);
    diff < tolerance
  }

  /// Runs `saturate_polar_zero` and checks both returned coordinates
  /// (radius first, then angle).
  fn check_saturate_polar_zero<X: RealField + Copy>(
    r_a: X,
    theta_a: X,
    r_b: X,
    delta: X,
    expected_r: X,
    expected_theta: X,
  ) {
    let (r_star, theta_star) = saturate_polar_zero(r_a, theta_a, r_b, delta);
    assert!(rel_eq(r_star, expected_r));
    assert!(rel_eq(theta_star, expected_theta));
  }

  #[test]
  fn test_polar_saturate() {
    let space =
      LeaderFollowerPolarSpace::from_params(LeaderFollowerPolarSpaceParams {
        bounds: Bounds::new([-5.0, -5.0].into(), [5.0, 5.0].into()),
        seed: Some(0x1234_5678_1234_5674),
        sensor_range: (0.1, 2.0),
      })
      .unwrap();

    let mut a = [3.7375317, 0.9111872, 0.90473485, 4.3599744].into();
    let b = [3.7432036, -2.570983, 2.502023, -4.0045195].into();

    space.saturate(&mut a, &b, 1.0);

    println!("a_star: {:?}", <[f32; N]>::from(a));
  }

  #[test]
  fn test_saturate_polar_zero_counter_clockwise_f32() {
    check_saturate_polar_zero::<f32>(
      1.7927681,
      1.4744684,
      0.70710677,
      0.275,
      0.893392428268,
      0.253,
    );
  }

  #[test]
  fn test_saturate_polar_zero_counter_clockwise_f64() {
    check_saturate_polar_zero::<f64>(
      1.7927681,
      1.4744684,
      0.70710677,
      0.275,
      0.893392428268,
      0.253,
    );
  }

  #[test]
  fn test_saturate_polar_zero_clockwise_f32() {
    check_saturate_polar_zero::<f32>(
      1.7927681,
      bound_theta(-1.4744684),
      0.70710677,
      0.275,
      0.893392428268,
      bound_theta(-0.253),
    );
  }

  #[test]
  fn test_saturate_polar_zero_clockwise_f64() {
    check_saturate_polar_zero::<f64>(
      1.7927681,
      bound_theta(-1.4744684),
      0.70710677,
      0.275,
      0.893392428268,
      bound_theta(-0.253),
    );
  }

  #[test]
  fn test_saturate_polar_zero_counter_clockwise_long_f32() {
    check_saturate_polar_zero::<f32>(
      1.7927681,
      3.127344367,
      0.70710677,
      0.275,
      0.820972361522,
      0.328,
    );
  }

  #[test]
  fn test_saturate_polar_zero_counter_clockwise_long_f64() {
    check_saturate_polar_zero::<f64>(
      1.7927681,
      3.127344367,
      0.70710677,
      0.275,
      0.820972361522,
      0.328,
    );
  }

  #[test]
  fn test_saturate_polar_zero_clockwise_long_f32() {
    check_saturate_polar_zero::<f32>(
      1.7927681,
      bound_theta(-3.127344367),
      0.70710677,
      0.275,
      0.820972361522,
      bound_theta(-0.328),
    );
  }

  #[test]
  fn test_saturate_polar_zero_clockwise_long_f64() {
    check_saturate_polar_zero::<f64>(
      1.7927681,
      bound_theta(-3.127344367),
      0.70710677,
      0.275,
      0.820972361522,
      bound_theta(-0.328),
    );
  }

  #[test]
  fn test_saturate_polar_zero_counter_clockwise_short_f32() {
    check_saturate_polar_zero::<f32>(
      1.7927681,
      0.3253262,
      0.70710677,
      0.275,
      0.974078524504,
      0.080,
    );
  }
}
